diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 8bf36c8bd..dcbe99333 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -103,6 +103,7 @@ namespace bolt { RArrow, RArrowAlt, LetKeyword, + FnKeyword, MutKeyword, PubKeyword, TypeKeyword, @@ -160,7 +161,8 @@ namespace bolt { Parameter, LetBlockBody, LetExprBody, - LetDeclaration, + FunctionDeclaration, + VariableDeclaration, RecordDeclarationField, RecordDeclaration, VariantDeclaration, @@ -549,6 +551,20 @@ namespace bolt { }; + class FnKeyword : public Token { + public: + + inline FnKeyword(TextLoc StartLoc): + Token(NodeKind::FnKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::FnKeyword; + } + + }; + class MutKeyword : public Token { public: @@ -1661,7 +1677,7 @@ namespace bolt { }; - class LetDeclaration : public Node { + class FunctionDeclaration : public Node { Scope* TheScope = nullptr; @@ -1672,26 +1688,23 @@ namespace bolt { class Type* Ty; class PubKeyword* PubKeyword; - class LetKeyword* LetKeyword; - class MutKeyword* MutKeyword; - class Pattern* Pattern; + class FnKeyword* FnKeyword; + class Identifier* Name; std::vector Params; class TypeAssert* TypeAssert; LetBody* Body; - LetDeclaration( + FunctionDeclaration( class PubKeyword* PubKeyword, - class LetKeyword* LetKeywod, - class MutKeyword* MutKeyword, - class Pattern* Pattern, + class FnKeyword* FnKeyword, + class Identifier* Name, std::vector Params, class TypeAssert* TypeAssert, LetBody* Body - ): Node(NodeKind::LetDeclaration), + ): Node(NodeKind::FunctionDeclaration), PubKeyword(PubKeyword), - LetKeyword(LetKeywod), - MutKeyword(MutKeyword), - Pattern(Pattern), + FnKeyword(FnKeyword), + Name(Name), Params(Params), TypeAssert(TypeAssert), Body(Body) {} @@ -1703,14 +1716,6 @@ namespace bolt { return TheScope; } - bool isFunc() const noexcept { - return !Params.empty(); - } - - bool isVar() const noexcept { - return !isFunc(); - } - bool isInstance() const noexcept { return Parent->getKind() == NodeKind::InstanceDeclaration; } @@ -1723,11 +1728,47 @@ namespace bolt { Token* getLastToken() const override; static bool classof(const Node* N) { - return N->getKind() == NodeKind::LetDeclaration; + return N->getKind() == NodeKind::FunctionDeclaration; } }; + class VariableDeclaration : public TypedNode { + + Scope* TheScope = nullptr; + + public: + + bool IsCycleActive = false; + + class PubKeyword* PubKeyword; + class LetKeyword* LetKeyword; + class MutKeyword* MutKeyword; + class Pattern* Pattern; + std::vector Params; + class TypeAssert* TypeAssert; + LetBody* Body; + + VariableDeclaration( + class PubKeyword* PubKeyword, + class LetKeyword* LetKeyword, + class MutKeyword* MutKeyword, + class Pattern* Pattern, + class TypeAssert* TypeAssert, + LetBody* Body + ): TypedNode(NodeKind::VariableDeclaration), + PubKeyword(PubKeyword), + LetKeyword(LetKeyword), + MutKeyword(MutKeyword), + Pattern(Pattern), + TypeAssert(TypeAssert), + Body(Body) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + class InstanceDeclaration : public Node { public: @@ -1977,6 +2018,7 @@ namespace bolt { template<> inline NodeKind getNodeType() { return NodeKind::RArrow; } template<> inline NodeKind getNodeType() { return NodeKind::RArrowAlt; } template<> inline NodeKind getNodeType() { return NodeKind::LetKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::FnKeyword; } template<> inline NodeKind getNodeType() { return NodeKind::MutKeyword; } template<> inline NodeKind getNodeType() { return NodeKind::PubKeyword; } template<> inline NodeKind getNodeType() { return NodeKind::TypeKeyword; } @@ -2019,7 +2061,8 @@ namespace bolt { template<> inline NodeKind getNodeType() { return NodeKind::Parameter; } template<> inline NodeKind getNodeType() { return NodeKind::LetBlockBody; } template<> inline NodeKind getNodeType() { return NodeKind::LetExprBody; } - template<> inline NodeKind getNodeType() { return NodeKind::LetDeclaration; } + template<> inline NodeKind getNodeType() { return NodeKind::FunctionDeclaration; } + template<> inline NodeKind getNodeType() { return NodeKind::VariableDeclaration; } template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclarationField; } template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclaration; } template<> inline NodeKind getNodeType() { return NodeKind::ClassDeclaration; } diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 753826fbe..74b1d9923 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -12,170 +12,96 @@ namespace bolt { public: void visit(Node* N) { + +#define BOLT_GEN_CASE(name) \ + case NodeKind::name: \ + return static_cast(this)->visit ## name(static_cast(N)); + switch (N->getKind()) { - case NodeKind::Equals: - return static_cast(this)->visitEquals(static_cast(N)); - case NodeKind::Colon: - return static_cast(this)->visitColon(static_cast(N)); - case NodeKind::Comma: - return static_cast(this)->visitComma(static_cast(N)); - case NodeKind::Dot: - return static_cast(this)->visitDot(static_cast(N)); - case NodeKind::DotDot: - return static_cast(this)->visitDotDot(static_cast(N)); - case NodeKind::Tilde: - return static_cast(this)->visitTilde(static_cast(N)); - case NodeKind::LParen: - return static_cast(this)->visitLParen(static_cast(N)); - case NodeKind::RParen: - return static_cast(this)->visitRParen(static_cast(N)); - case NodeKind::LBracket: - return static_cast(this)->visitLBracket(static_cast(N)); - case NodeKind::RBracket: - return static_cast(this)->visitRBracket(static_cast(N)); - case NodeKind::LBrace: - return static_cast(this)->visitLBrace(static_cast(N)); - case NodeKind::RBrace: - return static_cast(this)->visitRBrace(static_cast(N)); - case NodeKind::RArrow: - return static_cast(this)->visitRArrow(static_cast(N)); - case NodeKind::RArrowAlt: - return static_cast(this)->visitRArrowAlt(static_cast(N)); - case NodeKind::LetKeyword: - return static_cast(this)->visitLetKeyword(static_cast(N)); - case NodeKind::MutKeyword: - return static_cast(this)->visitMutKeyword(static_cast(N)); - case NodeKind::PubKeyword: - return static_cast(this)->visitPubKeyword(static_cast(N)); - case NodeKind::TypeKeyword: - return static_cast(this)->visitTypeKeyword(static_cast(N)); - case NodeKind::ReturnKeyword: - return static_cast(this)->visitReturnKeyword(static_cast(N)); - case NodeKind::ModKeyword: - return static_cast(this)->visitModKeyword(static_cast(N)); - case NodeKind::StructKeyword: - return static_cast(this)->visitStructKeyword(static_cast(N)); - case NodeKind::EnumKeyword: - return static_cast(this)->visitEnumKeyword(static_cast(N)); - case NodeKind::ClassKeyword: - return static_cast(this)->visitClassKeyword(static_cast(N)); - case NodeKind::InstanceKeyword: - return static_cast(this)->visitInstanceKeyword(static_cast(N)); - case NodeKind::ElifKeyword: - return static_cast(this)->visitElifKeyword(static_cast(N)); - case NodeKind::IfKeyword: - return static_cast(this)->visitIfKeyword(static_cast(N)); - case NodeKind::ElseKeyword: - return static_cast(this)->visitElseKeyword(static_cast(N)); - case NodeKind::MatchKeyword: - return static_cast(this)->visitMatchKeyword(static_cast(N)); - case NodeKind::Invalid: - return static_cast(this)->visitInvalid(static_cast(N)); - case NodeKind::EndOfFile: - return static_cast(this)->visitEndOfFile(static_cast(N)); - case NodeKind::BlockStart: - return static_cast(this)->visitBlockStart(static_cast(N)); - case NodeKind::BlockEnd: - return static_cast(this)->visitBlockEnd(static_cast(N)); - case NodeKind::LineFoldEnd: - return static_cast(this)->visitLineFoldEnd(static_cast(N)); - case NodeKind::CustomOperator: - return static_cast(this)->visitCustomOperator(static_cast(N)); - case NodeKind::Assignment: - return static_cast(this)->visitAssignment(static_cast(N)); - case NodeKind::Identifier: - return static_cast(this)->visitIdentifier(static_cast(N)); - case NodeKind::IdentifierAlt: - return static_cast(this)->visitIdentifierAlt(static_cast(N)); - case NodeKind::StringLiteral: - return static_cast(this)->visitStringLiteral(static_cast(N)); - case NodeKind::IntegerLiteral: - return static_cast(this)->visitIntegerLiteral(static_cast(N)); - case NodeKind::TypeclassConstraintExpression: - return static_cast(this)->visitTypeclassConstraintExpression(static_cast(N)); - case NodeKind::EqualityConstraintExpression: - return static_cast(this)->visitEqualityConstraintExpression(static_cast(N)); - case NodeKind::QualifiedTypeExpression: - return static_cast(this)->visitQualifiedTypeExpression(static_cast(N)); - case NodeKind::ReferenceTypeExpression: - return static_cast(this)->visitReferenceTypeExpression(static_cast(N)); - case NodeKind::ArrowTypeExpression: - return static_cast(this)->visitArrowTypeExpression(static_cast(N)); - case NodeKind::AppTypeExpression: - return static_cast(this)->visitAppTypeExpression(static_cast(N)); - case NodeKind::VarTypeExpression: - return static_cast(this)->visitVarTypeExpression(static_cast(N)); - case NodeKind::NestedTypeExpression: - return static_cast(this)->visitNestedTypeExpression(static_cast(N)); - case NodeKind::TupleTypeExpression: - return static_cast(this)->visitTupleTypeExpression(static_cast(N)); - case NodeKind::BindPattern: - return static_cast(this)->visitBindPattern(static_cast(N)); - case NodeKind::LiteralPattern: - return static_cast(this)->visitLiteralPattern(static_cast(N)); - case NodeKind::NamedPattern: - return static_cast(this)->visitNamedPattern(static_cast(N)); - case NodeKind::NestedPattern: - return static_cast(this)->visitNestedPattern(static_cast(N)); - case NodeKind::ReferenceExpression: - return static_cast(this)->visitReferenceExpression(static_cast(N)); - case NodeKind::MatchCase: - return static_cast(this)->visitMatchCase(static_cast(N)); - case NodeKind::MatchExpression: - return static_cast(this)->visitMatchExpression(static_cast(N)); - case NodeKind::MemberExpression: - return static_cast(this)->visitMemberExpression(static_cast(N)); - case NodeKind::TupleExpression: - return static_cast(this)->visitTupleExpression(static_cast(N)); - case NodeKind::NestedExpression: - return static_cast(this)->visitNestedExpression(static_cast(N)); - case NodeKind::ConstantExpression: - return static_cast(this)->visitConstantExpression(static_cast(N)); - case NodeKind::CallExpression: - return static_cast(this)->visitCallExpression(static_cast(N)); - case NodeKind::InfixExpression: - return static_cast(this)->visitInfixExpression(static_cast(N)); - case NodeKind::PrefixExpression: - return static_cast(this)->visitPrefixExpression(static_cast(N)); - case NodeKind::RecordExpressionField: - return static_cast(this)->visitRecordExpressionField(static_cast(N)); - case NodeKind::RecordExpression: - return static_cast(this)->visitRecordExpression(static_cast(N)); - case NodeKind::ExpressionStatement: - return static_cast(this)->visitExpressionStatement(static_cast(N)); - case NodeKind::ReturnStatement: - return static_cast(this)->visitReturnStatement(static_cast(N)); - case NodeKind::IfStatement: - return static_cast(this)->visitIfStatement(static_cast(N)); - case NodeKind::IfStatementPart: - return static_cast(this)->visitIfStatementPart(static_cast(N)); - case NodeKind::TypeAssert: - return static_cast(this)->visitTypeAssert(static_cast(N)); - case NodeKind::Parameter: - return static_cast(this)->visitParameter(static_cast(N)); - case NodeKind::LetBlockBody: - return static_cast(this)->visitLetBlockBody(static_cast(N)); - case NodeKind::LetExprBody: - return static_cast(this)->visitLetExprBody(static_cast(N)); - case NodeKind::LetDeclaration: - return static_cast(this)->visitLetDeclaration(static_cast(N)); - case NodeKind::RecordDeclarationField: - return static_cast(this)->visitRecordDeclarationField(static_cast(N)); - case NodeKind::RecordDeclaration: - return static_cast(this)->visitRecordDeclaration(static_cast(N)); - case NodeKind::VariantDeclaration: - return static_cast(this)->visitVariantDeclaration(static_cast(N)); - case NodeKind::TupleVariantDeclarationMember: - return static_cast(this)->visitTupleVariantDeclarationMember(static_cast(N)); - case NodeKind::RecordVariantDeclarationMember: - return static_cast(this)->visitRecordVariantDeclarationMember(static_cast(N)); - case NodeKind::ClassDeclaration: - return static_cast(this)->visitClassDeclaration(static_cast(N)); - case NodeKind::InstanceDeclaration: - return static_cast(this)->visitInstanceDeclaration(static_cast(N)); - case NodeKind::SourceFile: - return static_cast(this)->visitSourceFile(static_cast(N)); - } + BOLT_GEN_CASE(Equals) + BOLT_GEN_CASE(Colon) + BOLT_GEN_CASE(Comma) + BOLT_GEN_CASE(Dot) + BOLT_GEN_CASE(DotDot) + BOLT_GEN_CASE(Tilde) + BOLT_GEN_CASE(LParen) + BOLT_GEN_CASE(RParen) + BOLT_GEN_CASE(LBracket) + BOLT_GEN_CASE(RBracket) + BOLT_GEN_CASE(LBrace) + BOLT_GEN_CASE(RBrace) + BOLT_GEN_CASE(RArrow) + BOLT_GEN_CASE(RArrowAlt) + BOLT_GEN_CASE(LetKeyword) + BOLT_GEN_CASE(FnKeyword) + BOLT_GEN_CASE(MutKeyword) + BOLT_GEN_CASE(PubKeyword) + BOLT_GEN_CASE(TypeKeyword) + BOLT_GEN_CASE(ReturnKeyword) + BOLT_GEN_CASE(ModKeyword) + BOLT_GEN_CASE(StructKeyword) + BOLT_GEN_CASE(EnumKeyword) + BOLT_GEN_CASE(ClassKeyword) + BOLT_GEN_CASE(InstanceKeyword) + BOLT_GEN_CASE(ElifKeyword) + BOLT_GEN_CASE(IfKeyword) + BOLT_GEN_CASE(ElseKeyword) + BOLT_GEN_CASE(MatchKeyword) + BOLT_GEN_CASE(Invalid) + BOLT_GEN_CASE(EndOfFile) + BOLT_GEN_CASE(BlockStart) + BOLT_GEN_CASE(BlockEnd) + BOLT_GEN_CASE(LineFoldEnd) + BOLT_GEN_CASE(CustomOperator) + BOLT_GEN_CASE(Assignment) + BOLT_GEN_CASE(Identifier) + BOLT_GEN_CASE(IdentifierAlt) + BOLT_GEN_CASE(StringLiteral) + BOLT_GEN_CASE(IntegerLiteral) + BOLT_GEN_CASE(TypeclassConstraintExpression) + BOLT_GEN_CASE(EqualityConstraintExpression) + BOLT_GEN_CASE(QualifiedTypeExpression) + BOLT_GEN_CASE(ReferenceTypeExpression) + BOLT_GEN_CASE(ArrowTypeExpression) + BOLT_GEN_CASE(AppTypeExpression) + BOLT_GEN_CASE(VarTypeExpression) + BOLT_GEN_CASE(NestedTypeExpression) + BOLT_GEN_CASE(TupleTypeExpression) + BOLT_GEN_CASE(BindPattern) + BOLT_GEN_CASE(LiteralPattern) + BOLT_GEN_CASE(NamedPattern) + BOLT_GEN_CASE(NestedPattern) + BOLT_GEN_CASE(ReferenceExpression) + BOLT_GEN_CASE(MatchCase) + BOLT_GEN_CASE(MatchExpression) + BOLT_GEN_CASE(MemberExpression) + BOLT_GEN_CASE(TupleExpression) + BOLT_GEN_CASE(NestedExpression) + BOLT_GEN_CASE(ConstantExpression) + BOLT_GEN_CASE(CallExpression) + BOLT_GEN_CASE(InfixExpression) + BOLT_GEN_CASE(PrefixExpression) + BOLT_GEN_CASE(RecordExpressionField) + BOLT_GEN_CASE(RecordExpression) + BOLT_GEN_CASE(ExpressionStatement) + BOLT_GEN_CASE(ReturnStatement) + BOLT_GEN_CASE(IfStatement) + BOLT_GEN_CASE(IfStatementPart) + BOLT_GEN_CASE(TypeAssert) + BOLT_GEN_CASE(Parameter) + BOLT_GEN_CASE(LetBlockBody) + BOLT_GEN_CASE(LetExprBody) + BOLT_GEN_CASE(FunctionDeclaration) + BOLT_GEN_CASE(VariableDeclaration) + BOLT_GEN_CASE(RecordDeclaration) + BOLT_GEN_CASE(RecordDeclarationField) + BOLT_GEN_CASE(VariantDeclaration) + BOLT_GEN_CASE(TupleVariantDeclarationMember) + BOLT_GEN_CASE(RecordVariantDeclarationMember) + BOLT_GEN_CASE(ClassDeclaration) + BOLT_GEN_CASE(InstanceDeclaration) + BOLT_GEN_CASE(SourceFile) + } } protected: @@ -248,6 +174,10 @@ namespace bolt { visitToken(N); } + void visitFnKeyword(FnKeyword* N) { + visitToken(N); + } + void visitMutKeyword(MutKeyword* N) { visitToken(N); } @@ -500,7 +430,11 @@ namespace bolt { visitLetBody(N); } - void visitLetDeclaration(LetDeclaration* N) { + void visitFunctionDeclaration(FunctionDeclaration* N) { + visitNode(N); + } + + void visitVariableDeclaration(VariableDeclaration* N) { visitNode(N); } @@ -543,252 +477,96 @@ namespace bolt { public: void visitEachChild(Node* N) { + +#define BOLT_GEN_CHILD_CASE(name) \ + case NodeKind::name: \ + visitEachChild(static_cast(N)); \ + break; + switch (N->getKind()) { - case NodeKind::Equals: - visitEachChild(static_cast(N)); - break; - case NodeKind::Colon: - visitEachChild(static_cast(N)); - break; - case NodeKind::Comma: - visitEachChild(static_cast(N)); - break; - case NodeKind::Dot: - visitEachChild(static_cast(N)); - break; - case NodeKind::DotDot: - visitEachChild(static_cast(N)); - break; - case NodeKind::Tilde: - visitEachChild(static_cast(N)); - break; - case NodeKind::LParen: - visitEachChild(static_cast(N)); - break; - case NodeKind::RParen: - visitEachChild(static_cast(N)); - break; - case NodeKind::LBracket: - visitEachChild(static_cast(N)); - break; - case NodeKind::RBracket: - visitEachChild(static_cast(N)); - break; - case NodeKind::LBrace: - visitEachChild(static_cast(N)); - break; - case NodeKind::RBrace: - visitEachChild(static_cast(N)); - break; - case NodeKind::RArrow: - visitEachChild(static_cast(N)); - break; - case NodeKind::RArrowAlt: - visitEachChild(static_cast(N)); - break; - case NodeKind::LetKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::MutKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::PubKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::TypeKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::ReturnKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::ModKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::StructKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::EnumKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::ClassKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::InstanceKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::ElifKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::IfKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::ElseKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::MatchKeyword: - visitEachChild(static_cast(N)); - break; - case NodeKind::Invalid: - visitEachChild(static_cast(N)); - break; - case NodeKind::EndOfFile: - visitEachChild(static_cast(N)); - break; - case NodeKind::BlockStart: - visitEachChild(static_cast(N)); - break; - case NodeKind::BlockEnd: - visitEachChild(static_cast(N)); - break; - case NodeKind::LineFoldEnd: - visitEachChild(static_cast(N)); - break; - case NodeKind::CustomOperator: - visitEachChild(static_cast(N)); - break; - case NodeKind::Assignment: - visitEachChild(static_cast(N)); - break; - case NodeKind::Identifier: - visitEachChild(static_cast(N)); - break; - case NodeKind::IdentifierAlt: - visitEachChild(static_cast(N)); - break; - case NodeKind::StringLiteral: - visitEachChild(static_cast(N)); - break; - case NodeKind::IntegerLiteral: - visitEachChild(static_cast(N)); - break; - case NodeKind::TypeclassConstraintExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::EqualityConstraintExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::QualifiedTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::ReferenceTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::ArrowTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::AppTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::VarTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::NestedTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::TupleTypeExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::BindPattern: - visitEachChild(static_cast(N)); - break; - case NodeKind::LiteralPattern: - visitEachChild(static_cast(N)); - break; - case NodeKind::NamedPattern: - visitEachChild(static_cast(N)); - break; - case NodeKind::NestedPattern: - visitEachChild(static_cast(N)); - break; - case NodeKind::ReferenceExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::MatchCase: - visitEachChild(static_cast(N)); - break; - case NodeKind::MatchExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::MemberExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::TupleExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::NestedExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::ConstantExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::CallExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::InfixExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::PrefixExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::RecordExpressionField: - visitEachChild(static_cast(N)); - break; - case NodeKind::RecordExpression: - visitEachChild(static_cast(N)); - break; - case NodeKind::ExpressionStatement: - visitEachChild(static_cast(N)); - break; - case NodeKind::ReturnStatement: - visitEachChild(static_cast(N)); - break; - case NodeKind::IfStatement: - visitEachChild(static_cast(N)); - break; - case NodeKind::IfStatementPart: - visitEachChild(static_cast(N)); - break; - case NodeKind::TypeAssert: - visitEachChild(static_cast(N)); - break; - case NodeKind::Parameter: - visitEachChild(static_cast(N)); - break; - case NodeKind::LetBlockBody: - visitEachChild(static_cast(N)); - break; - case NodeKind::LetExprBody: - visitEachChild(static_cast(N)); - break; - case NodeKind::LetDeclaration: - visitEachChild(static_cast(N)); - break; - case NodeKind::RecordDeclaration: - visitEachChild(static_cast(N)); - break; - case NodeKind::RecordDeclarationField: - visitEachChild(static_cast(N)); - break; - case NodeKind::VariantDeclaration: - visitEachChild(static_cast(N)); - break; - case NodeKind::TupleVariantDeclarationMember: - visitEachChild(static_cast(N)); - break; - case NodeKind::RecordVariantDeclarationMember: - visitEachChild(static_cast(N)); - break; - case NodeKind::ClassDeclaration: - visitEachChild(static_cast(N)); - break; - case NodeKind::InstanceDeclaration: - visitEachChild(static_cast(N)); - break; - case NodeKind::SourceFile: - visitEachChild(static_cast(N)); - break; - default: - ZEN_UNREACHABLE + BOLT_GEN_CHILD_CASE(Equals) + BOLT_GEN_CHILD_CASE(Colon) + BOLT_GEN_CHILD_CASE(Comma) + BOLT_GEN_CHILD_CASE(Dot) + BOLT_GEN_CHILD_CASE(DotDot) + BOLT_GEN_CHILD_CASE(Tilde) + BOLT_GEN_CHILD_CASE(LParen) + BOLT_GEN_CHILD_CASE(RParen) + BOLT_GEN_CHILD_CASE(LBracket) + BOLT_GEN_CHILD_CASE(RBracket) + BOLT_GEN_CHILD_CASE(LBrace) + BOLT_GEN_CHILD_CASE(RBrace) + BOLT_GEN_CHILD_CASE(RArrow) + BOLT_GEN_CHILD_CASE(RArrowAlt) + BOLT_GEN_CHILD_CASE(LetKeyword) + BOLT_GEN_CHILD_CASE(FnKeyword) + BOLT_GEN_CHILD_CASE(MutKeyword) + BOLT_GEN_CHILD_CASE(PubKeyword) + BOLT_GEN_CHILD_CASE(TypeKeyword) + BOLT_GEN_CHILD_CASE(ReturnKeyword) + BOLT_GEN_CHILD_CASE(ModKeyword) + BOLT_GEN_CHILD_CASE(StructKeyword) + BOLT_GEN_CHILD_CASE(EnumKeyword) + BOLT_GEN_CHILD_CASE(ClassKeyword) + BOLT_GEN_CHILD_CASE(InstanceKeyword) + BOLT_GEN_CHILD_CASE(ElifKeyword) + BOLT_GEN_CHILD_CASE(IfKeyword) + BOLT_GEN_CHILD_CASE(ElseKeyword) + BOLT_GEN_CHILD_CASE(MatchKeyword) + BOLT_GEN_CHILD_CASE(Invalid) + BOLT_GEN_CHILD_CASE(EndOfFile) + BOLT_GEN_CHILD_CASE(BlockStart) + BOLT_GEN_CHILD_CASE(BlockEnd) + BOLT_GEN_CHILD_CASE(LineFoldEnd) + BOLT_GEN_CHILD_CASE(CustomOperator) + BOLT_GEN_CHILD_CASE(Assignment) + BOLT_GEN_CHILD_CASE(Identifier) + BOLT_GEN_CHILD_CASE(IdentifierAlt) + BOLT_GEN_CHILD_CASE(StringLiteral) + BOLT_GEN_CHILD_CASE(IntegerLiteral) + BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression) + BOLT_GEN_CHILD_CASE(EqualityConstraintExpression) + BOLT_GEN_CHILD_CASE(QualifiedTypeExpression) + BOLT_GEN_CHILD_CASE(ReferenceTypeExpression) + BOLT_GEN_CHILD_CASE(ArrowTypeExpression) + BOLT_GEN_CHILD_CASE(AppTypeExpression) + BOLT_GEN_CHILD_CASE(VarTypeExpression) + BOLT_GEN_CHILD_CASE(NestedTypeExpression) + BOLT_GEN_CHILD_CASE(TupleTypeExpression) + BOLT_GEN_CHILD_CASE(BindPattern) + BOLT_GEN_CHILD_CASE(LiteralPattern) + BOLT_GEN_CHILD_CASE(NamedPattern) + BOLT_GEN_CHILD_CASE(NestedPattern) + BOLT_GEN_CHILD_CASE(ReferenceExpression) + BOLT_GEN_CHILD_CASE(MatchCase) + BOLT_GEN_CHILD_CASE(MatchExpression) + BOLT_GEN_CHILD_CASE(MemberExpression) + BOLT_GEN_CHILD_CASE(TupleExpression) + BOLT_GEN_CHILD_CASE(NestedExpression) + BOLT_GEN_CHILD_CASE(ConstantExpression) + BOLT_GEN_CHILD_CASE(CallExpression) + BOLT_GEN_CHILD_CASE(InfixExpression) + BOLT_GEN_CHILD_CASE(PrefixExpression) + BOLT_GEN_CHILD_CASE(RecordExpressionField) + BOLT_GEN_CHILD_CASE(RecordExpression) + BOLT_GEN_CHILD_CASE(ExpressionStatement) + BOLT_GEN_CHILD_CASE(ReturnStatement) + BOLT_GEN_CHILD_CASE(IfStatement) + BOLT_GEN_CHILD_CASE(IfStatementPart) + BOLT_GEN_CHILD_CASE(TypeAssert) + BOLT_GEN_CHILD_CASE(Parameter) + BOLT_GEN_CHILD_CASE(LetBlockBody) + BOLT_GEN_CHILD_CASE(LetExprBody) + BOLT_GEN_CHILD_CASE(FunctionDeclaration) + BOLT_GEN_CHILD_CASE(VariableDeclaration) + BOLT_GEN_CHILD_CASE(RecordDeclaration) + BOLT_GEN_CHILD_CASE(RecordDeclarationField) + BOLT_GEN_CHILD_CASE(VariantDeclaration) + BOLT_GEN_CHILD_CASE(TupleVariantDeclarationMember) + BOLT_GEN_CHILD_CASE(RecordVariantDeclarationMember) + BOLT_GEN_CHILD_CASE(ClassDeclaration) + BOLT_GEN_CHILD_CASE(InstanceDeclaration) + BOLT_GEN_CHILD_CASE(SourceFile) } } @@ -839,6 +617,9 @@ namespace bolt { void visitEachChild(LetKeyword* N) { } + void visitEachChild(FnKeyword* N) { + } + void visitEachChild(MutKeyword* N) { } @@ -1136,18 +917,29 @@ namespace bolt { BOLT_VISIT(N->Expression); } - void visitEachChild(LetDeclaration* N) { + void visitEachChild(FunctionDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + BOLT_VISIT(N->FnKeyword); + BOLT_VISIT(N->Name); + for (auto Param: N->Params) { + BOLT_VISIT(Param); + } + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); + } + if (N->Body) { + BOLT_VISIT(N->Body); + } + } + + void visitEachChild(VariableDeclaration* N) { if (N->PubKeyword) { BOLT_VISIT(N->PubKeyword); } BOLT_VISIT(N->LetKeyword); - if (N->MutKeyword) { - BOLT_VISIT(N->MutKeyword); - } BOLT_VISIT(N->Pattern); - for (auto Param: N->Params) { - BOLT_VISIT(Param); - } if (N->TypeAssert) { BOLT_VISIT(N->TypeAssert); } diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 2edcdceec..eac01a147 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -197,16 +197,16 @@ namespace bolt { void addConstraint(Constraint* Constraint); void forwardDeclare(Node* Node); - void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); + void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); Type* inferExpression(Expression* Expression); - Type* inferTypeExpression(TypeExpression* TE); + Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true); Type* inferLiteral(Literal* Lit); Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); void infer(Node* node); - void inferLetDeclaration(LetDeclaration* N); + void inferFunctionDeclaration(FunctionDeclaration* N); Constraint* convertToConstraint(ConstraintExpression* C); diff --git a/include/bolt/Common.hpp b/include/bolt/Common.hpp index e9ca4e1db..2b247ef5d 100644 --- a/include/bolt/Common.hpp +++ b/include/bolt/Common.hpp @@ -9,7 +9,7 @@ namespace bolt { ConfigFlags_TypeVarsRequireForall = 1 << 0, }; - unsigned Flags; + unsigned Flags = 0; public: diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index 28763dfc7..0ba01ead0 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -109,9 +109,9 @@ namespace bolt { public: TypeclassSignature Sig; - LetDeclaration* Decl; + FunctionDeclaration* Decl; - inline TypeclassMissingDiagnostic(TypeclassSignature Sig, LetDeclaration* Decl): + inline TypeclassMissingDiagnostic(TypeclassSignature Sig, FunctionDeclaration* Decl): Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} inline Node* getNode() const override { diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index ef20b386d..9be4f7caa 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -124,7 +124,9 @@ namespace bolt { Node* parseLetBodyElement(); - LetDeclaration* parseLetDeclaration(); + FunctionDeclaration* parseFunctionDeclaration(); + + VariableDeclaration* parseVariableDeclaration(); Node* parseClassElement(); diff --git a/src/CST.cc b/src/CST.cc index 8856a8ea8..a18e956e9 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -69,9 +69,9 @@ namespace bolt { } break; } - case NodeKind::LetDeclaration: + case NodeKind::FunctionDeclaration: { - auto Decl = static_cast(X); + auto Decl = static_cast(X); for (auto Param: Decl->Params) { visitPattern(Param->Pattern, Param); } @@ -112,12 +112,18 @@ namespace bolt { } break; } - case NodeKind::LetDeclaration: + case NodeKind::VariableDeclaration: { - auto Decl = static_cast(X); + auto Decl = static_cast(X); visitPattern(Decl->Pattern, Decl); break; } + case NodeKind::FunctionDeclaration: + { + auto Decl = static_cast(X); + addSymbol(Decl->Name->getCanonicalText(), Decl, SymbolKind::Var); + break; + } case NodeKind::RecordDeclaration: { auto Decl = static_cast(X); @@ -597,14 +603,14 @@ namespace bolt { return Expression->getLastToken(); } - Token* LetDeclaration::getFirstToken() const { + Token* FunctionDeclaration::getFirstToken() const { if (PubKeyword) { return PubKeyword; } - return LetKeyword; + return FnKeyword; } - Token* LetDeclaration::getLastToken() const { + Token* FunctionDeclaration::getLastToken() const { if (Body) { return Body->getLastToken(); } @@ -614,6 +620,23 @@ namespace bolt { if (Params.size()) { return Params.back()->getLastToken(); } + return Name; + } + + Token* VariableDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + return LetKeyword; + } + + Token* VariableDeclaration::getLastToken() const { + if (Body) { + return Body->getLastToken(); + } + if (TypeAssert) { + return TypeAssert->getLastToken(); + } return Pattern->getLastToken(); } @@ -766,6 +789,10 @@ namespace bolt { return "let"; } + std::string FnKeyword::getText() const { + return "fn"; + } + std::string MutKeyword::getText() const { return "mut"; } diff --git a/src/Checker.cc b/src/Checker.cc index 395ea5f31..961c2439a 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -12,6 +12,12 @@ // TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node +// TODO When a forall variable is missing, do not just insert a blank one into the env. It will result in too few diagnostics being emitted. +// Same goes for reference expressions. +// If running the compiler as a language server, this matters. + +// TODO Add a pattern that only performs a type assert + #include #include #include @@ -296,10 +302,14 @@ namespace bolt { break; } - case NodeKind::LetDeclaration: + case NodeKind::FunctionDeclaration: // These declarations will be handled separately in check() break; + case NodeKind::VariableDeclaration: + // All of this node's semantics will be handled in infer() + break; + case NodeKind::VariantDeclaration: { auto Decl = static_cast(X); @@ -376,8 +386,8 @@ namespace bolt { for (auto TV: Vars) { RetTy = new TApp(RetTy, TV); } + Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); popContext(); - addBinding(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); break; } @@ -423,18 +433,18 @@ namespace bolt { Contexts.pop(); } - void visitLetDeclaration(LetDeclaration* Let) { - if (Let->isFunc()) { - Let->Ctx = createDerivedContext(); - Contexts.push(Let->Ctx); - visitEachChild(Let); - Contexts.pop(); - } else { - Let->Ctx = Contexts.top(); - visitEachChild(Let); - } + void visitFunctionDeclaration(FunctionDeclaration* Let) { + Let->Ctx = createDerivedContext(); + Contexts.push(Let->Ctx); + visitEachChild(Let); + Contexts.pop(); } + // void visitVariableDeclaration(VariableDeclaration* Var) { + // Var->Ctx = Contexts.top(); + // visitEachChild(Var); + // } + }; Init I { {}, *this }; @@ -442,9 +452,7 @@ namespace bolt { } - void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) { - - auto Let = static_cast(N); + void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { setContext(Let->Ctx); @@ -495,7 +503,7 @@ namespace bolt { Params.push_back(TV); } - auto SigLet = llvm::cast(Class->getScope()->lookupDirect({ {}, llvm::cast(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var)); + auto SigLet = llvm::cast(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var)); // It would be very strange if there was no type assert in the type // class let-declaration but we rather not let the compiler crash if that happens. @@ -520,9 +528,7 @@ namespace bolt { case NodeKind::LetBlockBody: { auto Block = static_cast(Let->Body); - if (Let->isFunc()) { - Let->Ctx->ReturnType = createTypeVar(); - } + Let->Ctx->ReturnType = createTypeVar(); for (auto Element: Block->Elements) { forwardDeclare(Element); } @@ -533,19 +539,11 @@ namespace bolt { } } - - Type* BindTy; - if (Let->isFunc()) { - popContext(); - BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs); - } else { - BindTy = inferPattern(Let->Pattern); - } - addConstraint(new CEqual(BindTy, Ty, Let)); + Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty)); } - void Checker::inferLetDeclaration(LetDeclaration* Decl) { + void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) { setContext(Decl->Ctx); @@ -553,7 +551,6 @@ namespace bolt { Type* RetType; for (auto Param: Decl->Params) { - // TODO incorporate Param->TypeAssert or make it a kind of pattern ParamTypes.push_back(inferPattern(Param->Pattern)); } @@ -568,7 +565,6 @@ namespace bolt { case NodeKind::LetBlockBody: { auto Block = static_cast(Decl->Body); - ZEN_ASSERT(Decl->isFunc()); RetType = Decl->Ctx->ReturnType; for (auto Element: Block->Elements) { infer(Element); @@ -582,13 +578,7 @@ namespace bolt { RetType = createTypeVar(); } - if (Decl->isFunc()) { - popContext(); - addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl }); - } else { - // Declaration is a plain (typed) variable - addConstraint(new CEqual { Decl->Ty, RetType, Decl }); - } + addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl }); } @@ -642,7 +632,7 @@ namespace bolt { break; } - case NodeKind::LetDeclaration: + case NodeKind::FunctionDeclaration: break; case NodeKind::ReturnStatement: @@ -658,6 +648,33 @@ namespace bolt { break; } + case NodeKind::VariableDeclaration: + { + auto Decl = static_cast(N); + Type* Ty = nullptr; + if (Decl->TypeAssert) { + Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false); + } + if (Decl->Body) { + ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); + auto E = static_cast(Decl->Body); + auto Ty2 = inferExpression(E->Expression); + if (Ty) { + addConstraint(new CEqual(Ty, Ty2, Decl)); + } else { + Ty = Ty2; + } + } + auto Ty3 = inferPattern(Decl->Pattern); + if (Ty) { + addConstraint(new CEqual(Ty, Ty3, Decl)); + } else { + Ty = Ty3; + } + Decl->setType(Ty); + break; + } + case NodeKind::ExpressionStatement: { auto ExprStmt = static_cast(N); @@ -764,7 +781,7 @@ namespace bolt { } } - Type* Checker::inferTypeExpression(TypeExpression* N) { + Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) { switch (N->getKind()) { @@ -786,9 +803,9 @@ namespace bolt { case NodeKind::AppTypeExpression: { auto AppTE = static_cast(N); - Type* Ty = inferTypeExpression(AppTE->Op); + Type* Ty = inferTypeExpression(AppTE->Op, IsPoly); for (auto Arg: AppTE->Args) { - Ty = new TApp(Ty, inferTypeExpression(Arg)); + Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly)); } return Ty; } @@ -798,10 +815,10 @@ namespace bolt { auto VarTE = static_cast(N); auto Ty = lookupMono(VarTE->Name->getCanonicalText()); if (Ty == nullptr) { - if (Config.typeVarsRequireForall()) { + if (IsPoly && Config.typeVarsRequireForall()) { DE.add(VarTE->Name->getCanonicalText(), VarTE->Name); } - Ty = createRigidVar(VarTE->Name->getCanonicalText()); + Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); } ZEN_ASSERT(Ty->getKind() == TypeKind::Var); @@ -814,7 +831,7 @@ namespace bolt { auto TupleTE = static_cast(N); std::vector ElementTypes; for (auto [TE, Comma]: TupleTE->Elements) { - ElementTypes.push_back(inferTypeExpression(TE)); + ElementTypes.push_back(inferTypeExpression(TE, IsPoly)); } auto Ty = new TTuple(ElementTypes); N->setType(Ty); @@ -824,7 +841,7 @@ namespace bolt { case NodeKind::NestedTypeExpression: { auto NestedTE = static_cast(N); - auto Ty = inferTypeExpression(NestedTE->TE); + auto Ty = inferTypeExpression(NestedTE->TE, IsPoly); N->setType(Ty); return Ty; } @@ -834,9 +851,9 @@ namespace bolt { auto ArrowTE = static_cast(N); std::vector ParamTypes; for (auto ParamType: ArrowTE->ParamTypes) { - ParamTypes.push_back(inferTypeExpression(ParamType)); + ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); } - auto ReturnType = inferTypeExpression(ArrowTE->ReturnType); + auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); auto Ty = new TArrow(ParamTypes, ReturnType); N->setType(Ty); return Ty; @@ -848,7 +865,7 @@ namespace bolt { for (auto [C, Comma]: QTE->Constraints) { addConstraint(convertToConstraint(C)); } - auto Ty = inferTypeExpression(QTE->TE); + auto Ty = inferTypeExpression(QTE->TE, IsPoly); N->setType(Ty); return Ty; } @@ -889,12 +906,13 @@ namespace bolt { } Ty = createTypeVar(); for (auto Case: Match->Cases) { + auto OldCtx = &getContext(); setContext(Case->Ctx); auto PattTy = inferPattern(Case->Pattern); addConstraint(new CEqual(PattTy, ValTy, Case)); auto ExprTy = inferExpression(Case->Expression); addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); - popContext(); + setContext(OldCtx); } if (!Match->Value) { Ty = new TArrow({ ValTy }, Ty); @@ -925,8 +943,8 @@ namespace bolt { auto Ref = static_cast(X); ZEN_ASSERT(Ref->ModulePath.empty()); auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); - if (Target && llvm::isa(Target)) { - auto Let = static_cast(Target); + if (Target && llvm::isa(Target)) { + auto Let = static_cast(Target); if (Let->IsCycleActive) { return Let->Ty; } @@ -1100,7 +1118,7 @@ namespace bolt { std::stack Stack; - void visitLetDeclaration(LetDeclaration* N) { + void visitFunctionDeclaration(FunctionDeclaration* N) { RefGraph.addVertex(N); Stack.push(N); visitEachChild(N); @@ -1121,7 +1139,7 @@ namespace bolt { RefGraph.addEdge(Stack.top(), Def->Parent); return; } - ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration); + ZEN_ASSERT(Def->getKind() == NodeKind::FunctionDeclaration || Def->getKind() == NodeKind::VariableDeclaration); if (!Stack.empty()) { RefGraph.addEdge(Def, Stack.top()); } @@ -1140,7 +1158,7 @@ namespace bolt { Checker& C; - void visitLetDeclaration(LetDeclaration* Decl) { + void visitLetDeclaration(FunctionDeclaration* Decl) { // Only inspect those let-declarations that look like a function if (Decl->Params.empty()) { @@ -1289,26 +1307,26 @@ namespace bolt { auto TVs = new TVSet; auto Constraints = new ConstraintSet; for (auto N: Nodes) { - auto Decl = static_cast(N); - forwardDeclareLetDeclaration(Decl, TVs, Constraints); + auto Decl = static_cast(N); + forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); } } for (auto Nodes: SCCs) { for (auto N: Nodes) { - auto Decl = static_cast(N); + auto Decl = static_cast(N); Decl->IsCycleActive = true; } for (auto N: Nodes) { - auto Decl = static_cast(N); - inferLetDeclaration(Decl); + auto Decl = static_cast(N); + inferFunctionDeclaration(Decl); } for (auto N: Nodes) { - auto Decl = static_cast(N); + auto Decl = static_cast(N); Decl->IsCycleActive = false; } } + setContext(SF->Ctx); infer(SF); - popContext(); solve(new CMany(*SF->Ctx->Constraints)); checkTypeclassSigs(SF); } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index e7dfeb9e8..2397d5751 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -100,6 +100,8 @@ namespace bolt { return "'pub'"; case NodeKind::LetKeyword: return "'let'"; + case NodeKind::FnKeyword: + return "'fn'"; case NodeKind::MutKeyword: return "'mut'"; case NodeKind::MatchKeyword: @@ -108,8 +110,10 @@ namespace bolt { return "'return'"; case NodeKind::TypeKeyword: return "'type'"; - case NodeKind::LetDeclaration: - return "a let-declaration"; + case NodeKind::FunctionDeclaration: + return "a function declaration"; + case NodeKind::VariableDeclaration: + return "a variable declaration"; case NodeKind::CallExpression: return "a call-expression"; case NodeKind::InfixExpression: diff --git a/src/Parser.cc b/src/Parser.cc index 0c60f35c0..7327291c8 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -853,7 +853,7 @@ finish: return new IfStatement(Parts); } - LetDeclaration* Parser::parseLetDeclaration() { +VariableDeclaration* Parser::parseVariableDeclaration() { PubKeyword* Pub = nullptr; LetKeyword* Let; @@ -881,8 +881,8 @@ finish: Tokens.get(); } - auto Patt = parsePattern(); - if (!Patt) { + auto P = parsePattern(); + if (!P) { if (Pub) { Pub->unref(); } @@ -894,27 +894,7 @@ finish: return nullptr; } - std::vector Params; - Token* T2; - for (;;) { - T2 = Tokens.peek(); - switch (T2->getKind()) { - case NodeKind::LineFoldEnd: - case NodeKind::BlockStart: - case NodeKind::Equals: - case NodeKind::Colon: - goto after_params; - default: - auto P = parsePattern(); - if (P == nullptr) { - P = new BindPattern(new Identifier("_")); - } - Params.push_back(new Parameter(P, nullptr)); - } - } - -after_params: - + auto T2 = Tokens.peek(); if (T2->getKind() == NodeKind::Colon) { Tokens.get(); auto TE = parseTypeExpression(); @@ -972,16 +952,137 @@ after_params: DE.add(File, T2, Expected); } -after_body: + checkLineFoldEnd(); + +finish: + + return new VariableDeclaration { Pub, Let, Mut, P, TA, Body }; + } + + FunctionDeclaration* Parser::parseFunctionDeclaration() { + + PubKeyword* Pub = nullptr; + FnKeyword* Fn; + MutKeyword* Mut = nullptr; + TypeAssert* TA = nullptr; + LetBody* Body = nullptr; + + auto T0 = Tokens.get(); + if (T0->getKind() == NodeKind::PubKeyword) { + Pub = static_cast(T0); + T0 = Tokens.get(); + } + if (T0->getKind() != NodeKind::FnKeyword) { + DE.add(File, T0, std::vector { NodeKind::FnKeyword }); + if (Pub) { + Pub->unref(); + } + skipToLineFoldEnd(); + return nullptr; + } + Fn = static_cast(T0); + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::MutKeyword) { + Mut = static_cast(T1); + Tokens.get(); + } + + auto Name = expectToken(); + if (!Name) { + if (Pub) { + Pub->unref(); + } + Fn->unref(); + if (Mut) { + Mut->unref(); + } + skipToLineFoldEnd(); + return nullptr; + } + + std::vector Params; + Token* T2; + for (;;) { + T2 = Tokens.peek(); + switch (T2->getKind()) { + case NodeKind::LineFoldEnd: + case NodeKind::BlockStart: + case NodeKind::Equals: + case NodeKind::Colon: + goto after_params; + default: + auto P = parsePattern(); + if (!P) { + P = new BindPattern(new Identifier("_")); + } + Params.push_back(new Parameter(P, nullptr)); + } + } + +after_params: + + if (T2->getKind() == NodeKind::Colon) { + Tokens.get(); + auto TE = parseTypeExpression(); + if (TE) { + TA = new TypeAssert(static_cast(T2), TE); + } else { + skipToLineFoldEnd(); + goto finish; + } + T2 = Tokens.peek(); + } + + switch (T2->getKind()) { + case NodeKind::BlockStart: + { + Tokens.get(); + std::vector Elements; + for (;;) { + auto T3 = Tokens.peek(); + if (T3->getKind() == NodeKind::BlockEnd) { + break; + } + auto Element = parseLetBodyElement(); + if (Element) { + Elements.push_back(Element); + } + } + Tokens.get()->unref(); // Always a BlockEnd + Body = new LetBlockBody(static_cast(T2), Elements); + break; + } + case NodeKind::Equals: + { + Tokens.get(); + auto E = parseExpression(); + if (!E) { + skipToLineFoldEnd(); + goto finish; + } + Body = new LetExprBody(static_cast(T2), E); + break; + } + case NodeKind::LineFoldEnd: + break; + default: + std::vector Expected { NodeKind::BlockStart, NodeKind::LineFoldEnd, NodeKind::Equals }; + if (TA == nullptr) { + // First tokens of TypeAssert + Expected.push_back(NodeKind::Colon); + // First tokens of Pattern + Expected.push_back(NodeKind::Identifier); + } + DE.add(File, T2, Expected); + } checkLineFoldEnd(); finish: - return new LetDeclaration( + return new FunctionDeclaration( Pub, - Let, - Mut, - Patt, + Fn, + Name, Params, TA, Body @@ -992,7 +1093,9 @@ finish: auto T0 = peekFirstTokenAfterModifiers(); switch (T0->getKind()) { case NodeKind::LetKeyword: - return parseLetDeclaration(); + return parseVariableDeclaration(); + case NodeKind::FnKeyword: + return parseFunctionDeclaration(); case NodeKind::ReturnKeyword: return parseReturnStatement(); case NodeKind::IfKeyword: @@ -1396,12 +1499,12 @@ next_member: Node* Parser::parseClassElement() { auto T0 = Tokens.peek(); switch (T0->getKind()) { - case NodeKind::LetKeyword: - return parseLetDeclaration(); + case NodeKind::FnKeyword: + return parseFunctionDeclaration(); case NodeKind::TypeKeyword: // TODO default: - DE.add(File, T0, std::vector { NodeKind::LetKeyword, NodeKind::TypeKeyword }); + DE.add(File, T0, std::vector { NodeKind::FnKeyword, NodeKind::TypeKeyword }); skipToLineFoldEnd(); return nullptr; } @@ -1411,7 +1514,9 @@ next_member: auto T0 = peekFirstTokenAfterModifiers(); switch (T0->getKind()) { case NodeKind::LetKeyword: - return parseLetDeclaration(); + return parseVariableDeclaration(); + case NodeKind::FnKeyword: + return parseFunctionDeclaration(); case NodeKind::IfKeyword: return parseIfStatement(); case NodeKind::ClassKeyword: diff --git a/src/Scanner.cc b/src/Scanner.cc index a1ea9ef2e..5f6b8a89d 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -62,6 +62,7 @@ namespace bolt { std::unordered_map Keywords = { { "pub", NodeKind::PubKeyword }, { "let", NodeKind::LetKeyword }, + { "fn", NodeKind::FnKeyword }, { "mut", NodeKind::MutKeyword }, { "return", NodeKind::ReturnKeyword }, { "type", NodeKind::TypeKeyword }, @@ -226,6 +227,8 @@ digit_finish: return new PubKeyword(StartLoc); case NodeKind::LetKeyword: return new LetKeyword(StartLoc); + case NodeKind::FnKeyword: + return new FnKeyword(StartLoc); case NodeKind::MutKeyword: return new MutKeyword(StartLoc); case NodeKind::TypeKeyword: