diff --git a/.vscode/launch.json b/.vscode/launch.json index 87e8e6c7b..a96936a12 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "request": "launch", "name": "Debug", "program": "${workspaceFolder}/build/bolt", - "args": [ "test.bolt" ], + "args": [ "--direct-diagnostics", "verify", "test/checker/local_constraints_polymorphic_variable.bolt" ], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: build" } diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 5efc4e968..bdfda4e4e 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -95,6 +96,7 @@ namespace bolt { Dot, DotDot, Tilde, + At, LParen, RParen, LBracket, @@ -129,6 +131,8 @@ namespace bolt { IdentifierAlt, StringLiteral, IntegerLiteral, + ExpressionAnnotation, + TypeAssertAnnotation, TypeclassConstraintExpression, EqualityConstraintExpression, QualifiedTypeExpression, @@ -150,7 +154,7 @@ namespace bolt { MemberExpression, TupleExpression, NestedExpression, - ConstantExpression, + LiteralExpression, CallExpression, InfixExpression, PrefixExpression, @@ -221,7 +225,7 @@ namespace bolt { template<> bool is() const noexcept { return Kind == NodeKind::ReferenceExpression - || Kind == NodeKind::ConstantExpression + || Kind == NodeKind::LiteralExpression || Kind == NodeKind::PrefixExpression || Kind == NodeKind::InfixExpression || Kind == NodeKind::CallExpression @@ -423,6 +427,20 @@ namespace bolt { }; + class At : public Token { + public: + + inline At(TextLoc StartLoc): + Token(NodeKind::At, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::At; + } + + }; + class LParen : public Token { public: @@ -949,6 +967,11 @@ namespace bolt { return V; } + inline int asInt() const { + ZEN_ASSERT(V >= std::numeric_limits::min() && V <= std::numeric_limits::max()); + return V; + } + LiteralValue getValue() override; static bool classof(const Node* N) { @@ -957,6 +980,63 @@ namespace bolt { }; + class Annotation : public Node { + public: + + inline Annotation(NodeKind Kind): + Node(Kind) {} + + }; + + class ExpressionAnnotation : public Annotation { + public: + + At* At; + Expression* Expression; + + inline ExpressionAnnotation( + class At* At, + class Expression* Expression + ): Annotation(NodeKind::ExpressionAnnotation), + At(At), + Expression(Expression) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline class Expression* getExpression() const noexcept { + return Expression; + } + + }; + + class TypeExpression; + + class TypeAssertAnnotation : public Annotation { + public: + + At* At; + Colon* Colon; + TypeExpression* TE; + + inline TypeAssertAnnotation( + class At* At, + class Colon* Colon, + TypeExpression* TE + ): Annotation(NodeKind::TypeAssertAnnotation), + At(At), + Colon(Colon), + TE(TE) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline TypeExpression* getTypeExpression() const noexcept { + return TE; + } + + }; + class TypedNode : public Node { protected: @@ -1307,10 +1387,15 @@ namespace bolt { class Expression : public TypedNode { + public: + + std::vector Annotations; + protected: - inline Expression(NodeKind Kind): - TypedNode(Kind) {} + inline Expression(NodeKind Kind, std::vector Annotations = {}): + TypedNode(Kind), Annotations(Annotations) {} + }; @@ -1320,13 +1405,25 @@ namespace bolt { std::vector> ModulePath; Symbol* Name; - ReferenceExpression( + inline ReferenceExpression( std::vector> ModulePath, Symbol* Name ): Expression(NodeKind::ReferenceExpression), ModulePath(ModulePath), Name(Name) {} + inline ReferenceExpression( + std::vector Annotations, + std::vector> ModulePath, + Symbol* Name + ): Expression(NodeKind::ReferenceExpression, Annotations), + ModulePath(ModulePath), + Name(Name) {} + + inline ByteString getNameAsString() const noexcept { + return Name->getCanonicalText(); + } + Token* getFirstToken() const override; Token* getLastToken() const override; @@ -1376,6 +1473,18 @@ namespace bolt { BlockStart(BlockStart), Cases(Cases) {} + inline MatchExpression( + std::vector Annotations, + class MatchKeyword* MatchKeyword, + Expression* Value, + class BlockStart* BlockStart, + std::vector Cases + ): Expression(NodeKind::MatchExpression, Annotations), + MatchKeyword(MatchKeyword), + Value(Value), + BlockStart(BlockStart), + Cases(Cases) {} + Token* getFirstToken() const override; Token* getLastToken() const override; @@ -1397,6 +1506,16 @@ namespace bolt { Dot(Dot), Name(Name) {} + inline MemberExpression( + std::vector Annotations, + class Expression* E, + class Dot* Dot, + Token* Name + ): Expression(NodeKind::MemberExpression, Annotations), + E(E), + Dot(Dot), + Name(Name) {} + Token* getFirstToken() const override; Token* getLastToken() const override; @@ -1422,6 +1541,16 @@ namespace bolt { Elements(Elements), RParen(RParen) {} + inline TupleExpression( + std::vector Annotations, + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): Expression(NodeKind::TupleExpression, Annotations), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + Token* getFirstToken() const override; Token* getLastToken() const override; @@ -1443,22 +1572,48 @@ namespace bolt { Inner(Inner), RParen(RParen) {} + inline NestedExpression( + std::vector Annotations, + class LParen* LParen, + Expression* Inner, + class RParen* RParen + ): Expression(NodeKind::NestedExpression, Annotations), + LParen(LParen), + Inner(Inner), + RParen(RParen) {} + Token* getFirstToken() const override; Token* getLastToken() const override; }; - class ConstantExpression : public Expression { + class LiteralExpression : public Expression { public: - class Literal* Token; + Literal* Token; - ConstantExpression( - class Literal* Token - ): Expression(NodeKind::ConstantExpression), + LiteralExpression( + Literal* Token + ): Expression(NodeKind::LiteralExpression), Token(Token) {} - class Token* getFirstToken() const override; + LiteralExpression( + std::vector Annotations, + Literal* Token + ): Expression(NodeKind::LiteralExpression, Annotations), + Token(Token) {} + + inline ByteString getAsText() { + ZEN_ASSERT(Token->getKind() == NodeKind::StringLiteral); + return static_cast(Token)->Text; + } + + inline int getAsInt() { + ZEN_ASSERT(Token->getKind() == NodeKind::IntegerLiteral); + return static_cast(Token)->asInt(); + } + + class Token* getFirstToken() const override; class Token* getLastToken() const override; }; @@ -1469,13 +1624,21 @@ namespace bolt { Expression* Function; std::vector Args; - CallExpression( + inline CallExpression( Expression* Function, std::vector Args ): Expression(NodeKind::CallExpression), Function(Function), Args(Args) {} + inline CallExpression( + std::vector Annotations, + Expression* Function, + std::vector Args + ): Expression(NodeKind::CallExpression, Annotations), + Function(Function), + Args(Args) {} + Token* getFirstToken() const override; Token* getLastToken() const override; @@ -1484,15 +1647,28 @@ namespace bolt { class InfixExpression : public Expression { public: - Expression* LHS; + Expression* Left; Token* Operator; - Expression* RHS; + Expression* Right; - InfixExpression(Expression* LHS, Token* Operator, Expression* RHS): - Expression(NodeKind::InfixExpression), - LHS(LHS), - Operator(Operator), - RHS(RHS) {} + inline InfixExpression( + Expression* Left, + Token* Operator, + Expression* Right + ): Expression(NodeKind::InfixExpression), + Left(Left), + Operator(Operator), + Right(Right) {} + + inline InfixExpression( + std::vector Annotations, + Expression* Left, + Token* Operator, + Expression* Right + ): Expression(NodeKind::InfixExpression, Annotations), + Left(Left), + Operator(Operator), + Right(Right) {} Token* getFirstToken() const override; Token* getLastToken() const override; @@ -2082,7 +2258,7 @@ namespace bolt { template<> inline NodeKind getNodeType() { return NodeKind::BindPattern; } template<> inline NodeKind getNodeType() { return NodeKind::ReferenceExpression; } template<> inline NodeKind getNodeType() { return NodeKind::NestedExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::ConstantExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::LiteralExpression; } template<> inline NodeKind getNodeType() { return NodeKind::CallExpression; } template<> inline NodeKind getNodeType() { return NodeKind::InfixExpression; } template<> inline NodeKind getNodeType() { return NodeKind::PrefixExpression; } diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index feee600d0..733b010bc 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -1,6 +1,7 @@ #pragma once +#include "CST.hpp" #include "zen/config.hpp" #include "bolt/CST.hpp" @@ -24,6 +25,7 @@ namespace bolt { BOLT_GEN_CASE(Dot) BOLT_GEN_CASE(DotDot) BOLT_GEN_CASE(Tilde) + BOLT_GEN_CASE(At) BOLT_GEN_CASE(LParen) BOLT_GEN_CASE(RParen) BOLT_GEN_CASE(LBracket) @@ -58,6 +60,8 @@ namespace bolt { BOLT_GEN_CASE(IdentifierAlt) BOLT_GEN_CASE(StringLiteral) BOLT_GEN_CASE(IntegerLiteral) + BOLT_GEN_CASE(ExpressionAnnotation) + BOLT_GEN_CASE(TypeAssertAnnotation) BOLT_GEN_CASE(TypeclassConstraintExpression) BOLT_GEN_CASE(EqualityConstraintExpression) BOLT_GEN_CASE(QualifiedTypeExpression) @@ -79,7 +83,7 @@ namespace bolt { BOLT_GEN_CASE(MemberExpression) BOLT_GEN_CASE(TupleExpression) BOLT_GEN_CASE(NestedExpression) - BOLT_GEN_CASE(ConstantExpression) + BOLT_GEN_CASE(LiteralExpression) BOLT_GEN_CASE(CallExpression) BOLT_GEN_CASE(InfixExpression) BOLT_GEN_CASE(PrefixExpression) @@ -112,371 +116,387 @@ namespace bolt { } void visitToken(Token* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitEquals(Equals* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitColon(Colon* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitComma(Comma* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitDot(Dot* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitDotDot(DotDot* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitTilde(Tilde* N) { - visitToken(N); + static_cast(this)->visitToken(N); + } + + void visitAt(At* N) { + static_cast(this)->visitToken(N); } void visitLParen(LParen* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitRParen(RParen* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitLBracket(LBracket* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitRBracket(RBracket* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitLBrace(LBrace* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitRBrace(RBrace* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitRArrow(RArrow* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitRArrowAlt(RArrowAlt* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitLetKeyword(LetKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitForeignKeyword(ForeignKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitMutKeyword(MutKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitPubKeyword(PubKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitTypeKeyword(TypeKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitReturnKeyword(ReturnKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitModKeyword(ModKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitStructKeyword(StructKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitEnumKeyword(EnumKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitClassKeyword(ClassKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitInstanceKeyword(InstanceKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitElifKeyword(ElifKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitIfKeyword(IfKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitElseKeyword(ElseKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitMatchKeyword(MatchKeyword* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitInvalid(Invalid* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitEndOfFile(EndOfFile* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitBlockStart(BlockStart* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitBlockEnd(BlockEnd* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitLineFoldEnd(LineFoldEnd* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitCustomOperator(CustomOperator* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitAssignment(Assignment* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitIdentifier(Identifier* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitIdentifierAlt(IdentifierAlt* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitStringLiteral(StringLiteral* N) { - visitToken(N); + static_cast(this)->visitToken(N); } void visitIntegerLiteral(IntegerLiteral* N) { - visitToken(N); + static_cast(this)->visitToken(N); + } + + void visitAnnotation(Annotation* N) { + static_cast(this)->visitNode(N); + } + + void visitTypeAssertAnnotation(TypeAssertAnnotation* N) { + static_cast(this)->visitAnnotation(N); + } + + void visitExpressionAnnotation(ExpressionAnnotation* N) { + static_cast(this)->visitAnnotation(N); } void visitConstraintExpression(ConstraintExpression* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitTypeclassConstraintExpression(TypeclassConstraintExpression* N) { - visitConstraintExpression(N); + static_cast(this)->visitConstraintExpression(N); } void visitEqualityConstraintExpression(EqualityConstraintExpression* N) { - visitConstraintExpression(N); + static_cast(this)->visitConstraintExpression(N); } void visitTypeExpression(TypeExpression* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitQualifiedTypeExpression(QualifiedTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitReferenceTypeExpression(ReferenceTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitArrowTypeExpression(ArrowTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitAppTypeExpression(AppTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitVarTypeExpression(VarTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitNestedTypeExpression(NestedTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitTupleTypeExpression(TupleTypeExpression* N) { - visitTypeExpression(N); + static_cast(this)->visitTypeExpression(N); } void visitPattern(Pattern* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitBindPattern(BindPattern* N) { - visitPattern(N); + static_cast(this)->visitPattern(N); } void visitLiteralPattern(LiteralPattern* N) { - visitPattern(N); + static_cast(this)->visitPattern(N); } void visitNamedPattern(NamedPattern* N) { - visitPattern(N); + static_cast(this)->visitPattern(N); } void visitTuplePattern(TuplePattern* N) { - visitPattern(N); + static_cast(this)->visitPattern(N); } void visitNestedPattern(NestedPattern* N) { - visitPattern(N); + static_cast(this)->visitPattern(N); } void visitListPattern(ListPattern* N) { - visitPattern(N); + static_cast(this)->visitPattern(N); } void visitExpression(Expression* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitReferenceExpression(ReferenceExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitMatchCase(MatchCase* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitMatchExpression(MatchExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitMemberExpression(MemberExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitTupleExpression(TupleExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitNestedExpression(NestedExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } - void visitConstantExpression(ConstantExpression* N) { - visitExpression(N); + void visitLiteralExpression(LiteralExpression* N) { + static_cast(this)->visitExpression(N); } void visitCallExpression(CallExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitInfixExpression(InfixExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitPrefixExpression(PrefixExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitRecordExpressionField(RecordExpressionField* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitRecordExpression(RecordExpression* N) { - visitExpression(N); + static_cast(this)->visitExpression(N); } void visitStatement(Statement* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitExpressionStatement(ExpressionStatement* N) { - visitStatement(N); + static_cast(this)->visitStatement(N); } void visitReturnStatement(ReturnStatement* N) { - visitStatement(N); + static_cast(this)->visitStatement(N); } void visitIfStatement(IfStatement* N) { - visitStatement(N); + static_cast(this)->visitStatement(N); } void visitIfStatementPart(IfStatementPart* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitTypeAssert(TypeAssert* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitParameter(Parameter* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitLetBody(LetBody* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitLetBlockBody(LetBlockBody* N) { - visitLetBody(N); + static_cast(this)->visitLetBody(N); } void visitLetExprBody(LetExprBody* N) { - visitLetBody(N); + static_cast(this)->visitLetBody(N); } void visitLetDeclaration(LetDeclaration* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitRecordDeclarationField(RecordDeclarationField* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitRecordDeclaration(RecordDeclaration* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitVariantDeclaration(VariantDeclaration* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitVariantDeclarationMember(VariantDeclarationMember* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitTupleVariantDeclarationMember(TupleVariantDeclarationMember* N) { - visitVariantDeclarationMember(N); + static_cast(this)->visitVariantDeclarationMember(N); } void visitRecordVariantDeclarationMember(RecordVariantDeclarationMember* N) { - visitVariantDeclarationMember(N); + static_cast(this)->visitVariantDeclarationMember(N); } void visitClassDeclaration(ClassDeclaration* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitInstanceDeclaration(InstanceDeclaration* N) { - visitNode(N); + static_cast(this)->visitNode(N); } void visitSourceFile(SourceFile* N) { - visitNode(N); + static_cast(this)->visitNode(N); } public: @@ -495,6 +515,7 @@ namespace bolt { BOLT_GEN_CHILD_CASE(Dot) BOLT_GEN_CHILD_CASE(DotDot) BOLT_GEN_CHILD_CASE(Tilde) + BOLT_GEN_CHILD_CASE(At) BOLT_GEN_CHILD_CASE(LParen) BOLT_GEN_CHILD_CASE(RParen) BOLT_GEN_CHILD_CASE(LBracket) @@ -529,6 +550,8 @@ namespace bolt { BOLT_GEN_CHILD_CASE(IdentifierAlt) BOLT_GEN_CHILD_CASE(StringLiteral) BOLT_GEN_CHILD_CASE(IntegerLiteral) + BOLT_GEN_CHILD_CASE(ExpressionAnnotation) + BOLT_GEN_CHILD_CASE(TypeAssertAnnotation) BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression) BOLT_GEN_CHILD_CASE(EqualityConstraintExpression) BOLT_GEN_CHILD_CASE(QualifiedTypeExpression) @@ -550,7 +573,7 @@ namespace bolt { BOLT_GEN_CHILD_CASE(MemberExpression) BOLT_GEN_CHILD_CASE(TupleExpression) BOLT_GEN_CHILD_CASE(NestedExpression) - BOLT_GEN_CHILD_CASE(ConstantExpression) + BOLT_GEN_CHILD_CASE(LiteralExpression) BOLT_GEN_CHILD_CASE(CallExpression) BOLT_GEN_CHILD_CASE(InfixExpression) BOLT_GEN_CHILD_CASE(PrefixExpression) @@ -596,6 +619,9 @@ namespace bolt { void visitEachChild(Tilde* N) { } + void visitEachChild(At* N) { + } + void visitEachChild(LParen* N) { } @@ -698,6 +724,17 @@ namespace bolt { void visitEachChild(IntegerLiteral* N) { } + void visitEachChild(ExpressionAnnotation* N) { + BOLT_VISIT(N->At); + BOLT_VISIT(N->Expression); + } + + void visitEachChild(TypeAssertAnnotation* N) { + BOLT_VISIT(N->At); + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TE); + } + void visitEachChild(TypeclassConstraintExpression* N) { BOLT_VISIT(N->Name); for (auto TE: N->TEs) { @@ -809,6 +846,9 @@ namespace bolt { } void visitEachChild(ReferenceExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } for (auto [Name, Dot]: N->ModulePath) { BOLT_VISIT(Name); BOLT_VISIT(Dot); @@ -823,6 +863,9 @@ namespace bolt { } void visitEachChild(MatchExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->MatchKeyword); if (N->Value) { BOLT_VISIT(N->Value); @@ -834,12 +877,18 @@ namespace bolt { } void visitEachChild(MemberExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->getExpression()); BOLT_VISIT(N->Dot); BOLT_VISIT(N->Name); } void visitEachChild(TupleExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->LParen); for (auto [E, Comma]: N->Elements) { BOLT_VISIT(E); @@ -851,16 +900,25 @@ namespace bolt { } void visitEachChild(NestedExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->LParen); BOLT_VISIT(N->Inner); BOLT_VISIT(N->RParen); } - void visitEachChild(ConstantExpression* N) { + void visitEachChild(LiteralExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->Token); } void visitEachChild(CallExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->Function); for (auto Arg: N->Args) { BOLT_VISIT(Arg); @@ -868,12 +926,18 @@ namespace bolt { } void visitEachChild(InfixExpression* N) { - BOLT_VISIT(N->LHS); + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + BOLT_VISIT(N->Left); BOLT_VISIT(N->Operator); - BOLT_VISIT(N->RHS); + BOLT_VISIT(N->Right); } void visitEachChild(PrefixExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->Operator); BOLT_VISIT(N->Argument); } @@ -885,6 +949,9 @@ namespace bolt { } void visitEachChild(RecordExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } BOLT_VISIT(N->LBrace); for (auto [Field, Comma]: N->Fields) { BOLT_VISIT(Field); diff --git a/include/bolt/DiagnosticEngine.hpp b/include/bolt/DiagnosticEngine.hpp index 272a0b35e..d3f73bdf5 100644 --- a/include/bolt/DiagnosticEngine.hpp +++ b/include/bolt/DiagnosticEngine.hpp @@ -54,6 +54,8 @@ namespace bolt { Diagnostics.clear(); } + void sort(); + std::size_t countDiagnostics() const noexcept { return Diagnostics.size(); } @@ -217,6 +219,7 @@ namespace bolt { void write(const std::string_view& S); void write(std::size_t N); + void write(char C); public: diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index e89b3b596..8e2067da0 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -1,11 +1,8 @@ #pragma once -#include #include -#include #include -#include #include "bolt/ByteString.hpp" #include "bolt/String.hpp" @@ -21,7 +18,6 @@ namespace bolt { UnificationError, TypeclassMissing, InstanceNotFound, - ClassNotFound, TupleIndexOutOfRange, InvalidTypeToTypeclass, FieldNotFound, @@ -45,6 +41,23 @@ namespace bolt { return nullptr; } + virtual unsigned getCode() const noexcept = 0; + + }; + + class UnexpectedStringDiagnostic : public Diagnostic { + public: + + TextFile& File; + TextLoc Location; + String Actual; + + inline UnexpectedStringDiagnostic(TextFile& File, TextLoc Location, String Actual): + Diagnostic(DiagnosticKind::UnexpectedString), File(File), Location(Location), Actual(Actual) {} + + unsigned getCode() const noexcept override { + return 1001; + } }; @@ -58,17 +71,9 @@ namespace bolt { inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector Expected): Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {} - }; - - class UnexpectedStringDiagnostic : public Diagnostic { - public: - - TextFile& File; - TextLoc Location; - String Actual; - - inline UnexpectedStringDiagnostic(TextFile& File, TextLoc Location, String Actual): - Diagnostic(DiagnosticKind::UnexpectedString), File(File), Location(Location), Actual(Actual) {} + unsigned getCode() const noexcept override { + return 1101; + } }; @@ -85,6 +90,10 @@ namespace bolt { return Initiator; } + unsigned getCode() const noexcept override { + return 2005; + } + }; class UnificationErrorDiagnostic : public Diagnostic { @@ -111,6 +120,10 @@ namespace bolt { return Source; } + unsigned getCode() const noexcept override { + return 2010; + } + }; class TypeclassMissingDiagnostic : public Diagnostic { @@ -126,6 +139,10 @@ namespace bolt { return Decl; } + unsigned getCode() const noexcept override { + return 2201; + } + }; class InstanceNotFoundDiagnostic : public Diagnostic { @@ -142,15 +159,9 @@ namespace bolt { return Source; } - }; - - class ClassNotFoundDiagnostic : public Diagnostic { - public: - - ByteString Name; - - inline ClassNotFoundDiagnostic(ByteString Name): - Diagnostic(DiagnosticKind::ClassNotFound), Name(Name) {} + unsigned getCode() const noexcept override { + return 2251; + } }; @@ -163,6 +174,10 @@ namespace bolt { inline TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I): Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {} + unsigned getCode() const noexcept override { + return 2015; + } + }; class InvalidTypeToTypeclassDiagnostic : public Diagnostic { @@ -179,6 +194,10 @@ namespace bolt { return Source; } + unsigned getCode() const noexcept override { + return 2060; + } + }; class FieldNotFoundDiagnostic : public Diagnostic { @@ -192,6 +211,10 @@ namespace bolt { inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source): Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} + unsigned getCode() const noexcept override { + return 2017; + } + }; } diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index e28f2e5a5..9e24651d6 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -94,6 +94,8 @@ namespace bolt { VarTypeExpression* parseVarTypeExpression(); ReferenceTypeExpression* parseReferenceTypeExpression(); + std::vector parseAnnotations(); + void checkLineFoldEnd(); void skipToLineFoldEnd(); diff --git a/include/bolt/Scanner.hpp b/include/bolt/Scanner.hpp index 3acbbd031..ba16e276c 100644 --- a/include/bolt/Scanner.hpp +++ b/include/bolt/Scanner.hpp @@ -13,9 +13,12 @@ namespace bolt { class Token; + class DiagnosticEngine; class Scanner : public BufferedStream { + DiagnosticEngine& DE; + TextFile& File; Stream& Chars; @@ -41,13 +44,17 @@ namespace bolt { return Chars.peek(Offset); } + std::string scanIdentifier(); + + Token* readNullable(); + protected: Token* read() override; public: - Scanner(TextFile& File, Stream& Chars); + Scanner(DiagnosticEngine& DE, TextFile& File, Stream& Chars); }; diff --git a/src/CST.cc b/src/CST.cc index 2126333b3..c33a87df4 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -345,6 +345,22 @@ namespace bolt { return true; } + Token* ExpressionAnnotation::getFirstToken() const { + return At; + } + + Token* ExpressionAnnotation::getLastToken() const { + return Expression->getLastToken(); + } + + Token* TypeAssertAnnotation::getFirstToken() const { + return At; + } + + Token* TypeAssertAnnotation::getLastToken() const { + return TE->getLastToken(); + } + Token* TypeclassConstraintExpression::getFirstToken() const { return Name; } @@ -553,11 +569,11 @@ namespace bolt { return RParen; } - Token* ConstantExpression::getFirstToken() const { + Token* LiteralExpression::getFirstToken() const { return Token; } - Token* ConstantExpression::getLastToken() const { + Token* LiteralExpression::getLastToken() const { return Token; } @@ -573,11 +589,11 @@ namespace bolt { } Token* InfixExpression::getFirstToken() const { - return LHS->getFirstToken(); + return Left->getFirstToken(); } Token* InfixExpression::getLastToken() const { - return RHS->getLastToken(); + return Right->getLastToken(); } Token* PrefixExpression::getFirstToken() const { @@ -938,6 +954,10 @@ namespace bolt { return "~"; } + std::string At::getText() const { + return "@"; + } + std::string ClassKeyword::getText() const { return "class"; } diff --git a/src/Checker.cc b/src/Checker.cc index 636ac9859..3f143a81b 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -904,6 +904,12 @@ namespace bolt { Type* Ty; + for (auto A: X->Annotations) { + if (A->getKind() == NodeKind::TypeAssertAnnotation) { + inferTypeExpression(static_cast(A)->TE); + } + } + switch (X->getKind()) { case NodeKind::MatchExpression: @@ -942,9 +948,9 @@ namespace bolt { break; } - case NodeKind::ConstantExpression: + case NodeKind::LiteralExpression: { - auto Const = static_cast(X); + auto Const = static_cast(X); Ty = inferLiteral(Const->Token); break; } @@ -1009,8 +1015,8 @@ namespace bolt { auto OpTy = instantiate(Scm, Infix->Operator); Ty = createTypeVar(); std::vector ArgTys; - ArgTys.push_back(inferExpression(Infix->LHS)); - ArgTys.push_back(inferExpression(Infix->RHS)); + ArgTys.push_back(inferExpression(Infix->Left)); + ArgTys.push_back(inferExpression(Infix->Right)); makeEqual(TArrow::build(ArgTys, Ty), OpTy, X); break; } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 5937e7a33..99c7f283f 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -48,6 +48,25 @@ namespace bolt { Diagnostic::Diagnostic(DiagnosticKind Kind): std::runtime_error("a compiler error occurred without being caught"), Kind(Kind) {} + bool sourceLocLessThan(const Diagnostic* L, const Diagnostic* R) { + auto N1 = L->getNode(); + auto N2 = R->getNode(); + if (N1 == nullptr && N2 == nullptr) { + return false; + } + if (N1 == nullptr) { + return true; + } + if (N2 == nullptr) { + return false; + } + return N1->getStartLine() < N2->getStartLine() || N1->getStartColumn() < N2->getStartColumn(); + }; + + void DiagnosticStore::sort() { + std::sort(Diagnostics.begin(), Diagnostics.end(), sourceLocLessThan); + } + static std::string describe(NodeKind Type) { switch (Type) { case NodeKind::Identifier: @@ -122,7 +141,7 @@ namespace bolt { return "a function or variable reference"; case NodeKind::MatchExpression: return "a match-expression"; - case NodeKind::ConstantExpression: + case NodeKind::LiteralExpression: return "a literal expression"; case NodeKind::MemberExpression: return "an accessor of a member"; @@ -132,6 +151,8 @@ namespace bolt { return "a branch of an if-statement"; case NodeKind::ListPattern: return "a list pattern"; + case NodeKind::TypeAssertAnnotation: + return "an annotation for a type assertion"; default: ZEN_UNREACHABLE } @@ -478,6 +499,10 @@ namespace bolt { Out << S; } + void ConsoleDiagnostics::write(char C) { + Out << C; + } + void ConsoleDiagnostics::write(std::size_t I) { Out << I; } @@ -849,16 +874,6 @@ namespace bolt { break; } - case DiagnosticKind::ClassNotFound: - { - auto E = static_cast(D); - writePrefix(E); - write("the type class "); - writeTypeclassName(E.Name); - write(" was not found.\n\n"); - break; - } - case DiagnosticKind::TupleIndexOutOfRange: { auto E = static_cast(D); diff --git a/src/Evaluator.cc b/src/Evaluator.cc index 3ceffb252..a63312a2b 100644 --- a/src/Evaluator.cc +++ b/src/Evaluator.cc @@ -16,9 +16,9 @@ namespace bolt { // ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration); // return static_cast(Decl); } - case NodeKind::ConstantExpression: + case NodeKind::LiteralExpression: { - auto CE = static_cast(X); + auto CE = static_cast(X); switch (CE->Token->getKind()) { case NodeKind::IntegerLiteral: return static_cast(CE->Token)->V; diff --git a/src/Parser.cc b/src/Parser.cc index 56e2d08b5..7215e20e2 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -609,6 +609,7 @@ after_tuple_element: } Expression* Parser::parsePrimitiveExpression() { + auto Annotations = parseAnnotations(); auto T0 = Tokens.peek(); switch (T0->getKind()) { case NodeKind::Identifier: @@ -634,7 +635,7 @@ after_tuple_element: DE.add(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt }); return nullptr; } - return new ReferenceExpression(ModulePath, static_cast(T3)); + return new ReferenceExpression(Annotations, ModulePath, static_cast(T3)); } case NodeKind::LParen: { @@ -687,21 +688,29 @@ after_tuple_element: } after_tuple_elements: if (Elements.size() == 1 && !std::get<1>(Elements.front())) { - return new NestedExpression(LParen, std::get<0>(Elements.front()), RParen); + return new NestedExpression(Annotations, LParen, std::get<0>(Elements.front()), RParen); } - return new TupleExpression { LParen, Elements, RParen }; + return new TupleExpression { Annotations, LParen, Elements, RParen }; } case NodeKind::MatchKeyword: return parseMatchExpression(); case NodeKind::IntegerLiteral: case NodeKind::StringLiteral: Tokens.get(); - return new ConstantExpression(static_cast(T0)); + return new LiteralExpression(Annotations, static_cast(T0)); case NodeKind::LBrace: return parseRecordExpression(); default: // Tokens.get(); - DE.add(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::LBrace, NodeKind::IntegerLiteral, NodeKind::StringLiteral }); + DE.add(File, T0, std::vector { + NodeKind::MatchKeyword, + NodeKind::Identifier, + NodeKind::IdentifierAlt, + NodeKind::LParen, + NodeKind::LBrace, + NodeKind::IntegerLiteral, + NodeKind::StringLiteral + }); return nullptr; } } @@ -722,7 +731,8 @@ after_tuple_elements: case NodeKind::Identifier: Tokens.get(); Tokens.get(); - E = new MemberExpression { E, static_cast(T1), T2 }; + E = new MemberExpression { E->Annotations, E, static_cast(T1), T2 }; + E->Annotations.clear(); break; default: goto finish; @@ -761,7 +771,9 @@ finish: if (Args.empty()) { return Operator; } - return new CallExpression(Operator, Args); + auto Annotations = Operator->Annotations; + Operator->Annotations.clear(); + return new CallExpression(Annotations, Operator, Args); } Expression* Parser::parseUnaryExpression() { @@ -1518,6 +1530,54 @@ next_member: return new SourceFile(File, Elements); } + std::vector Parser::parseAnnotations() { + std::vector Annotations; + for (;;) { + auto T0 = Tokens.peek(); + if (T0->getKind() != NodeKind::At) { + break; + } + auto At = static_cast(T0); + Tokens.get(); + auto T1 = Tokens.peek(); + switch (T1->getKind()) { + case NodeKind::Colon: + { + auto Colon = static_cast(T1); + Tokens.get(); + auto TE = parsePrimitiveTypeExpression(); + if (!TE) { + // TODO + continue; + } + Annotations.push_back(new TypeAssertAnnotation { At, Colon, TE }); + continue; + } + default: + { + // auto Name = static_cast(T1); + // Tokens.get(); + auto E = parseExpression(); + if (!E) { + At->unref(); + skipToLineFoldEnd(); + continue; + } + checkLineFoldEnd(); + Annotations.push_back(new ExpressionAnnotation { At, E }); + break; + } + // default: + // DE.add(File, T1, std::vector { NodeKind::Colon, NodeKind::Identifier }); + // At->unref(); + // skipToLineFoldEnd(); + // break; + } +next_annotation:; + } + return Annotations; + } + void Parser::skipToLineFoldEnd() { unsigned Level = 0; for (;;) { diff --git a/src/Scanner.cc b/src/Scanner.cc index f6f3ab461..61f4e8649 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -9,6 +9,7 @@ #include "bolt/Integer.hpp" #include "bolt/CST.hpp" #include "bolt/Diagnostics.hpp" +#include "bolt/DiagnosticEngine.hpp" #include "bolt/Scanner.hpp" namespace bolt { @@ -47,6 +48,12 @@ namespace bolt { } } + static bool isDirectiveIdentifierStart(Char Chr) { + return (Chr >= 65 && Chr <= 90) // Uppercase letter + || (Chr >= 96 && Chr <= 122) // Lowercase letter + || Chr == '_'; + } + static bool isIdentifierPart(Char Chr) { return (Chr >= 65 && Chr <= 90) // Uppercase letter || (Chr >= 96 && Chr <= 122) // Lowercase letter @@ -77,10 +84,29 @@ namespace bolt { { "enum", NodeKind::EnumKeyword }, }; - Scanner::Scanner(TextFile& File, Stream& Chars): - File(File), Chars(Chars) {} + Scanner::Scanner(DiagnosticEngine& DE, TextFile& File, Stream& Chars): + DE(DE), File(File), Chars(Chars) {} - Token* Scanner::read() { + std::string Scanner::scanIdentifier() { + auto Loc = getCurrentLoc(); + auto C0 = getChar(); + if (!isDirectiveIdentifierStart(C0)) { + DE.add(File, Loc, std::string { C0 }); + return nullptr; + } + ByteString Text { static_cast(C0) }; + for (;;) { + auto C1 = peekChar(); + if (!isIdentifierPart(C1)) { + break; + } + Text.push_back(C1); + getChar(); + } + return Text; +} + + Token* Scanner::readNullable() { TextLoc StartLoc; Char C0; @@ -92,6 +118,23 @@ namespace bolt { continue; } if (C0 == '#') { + getChar(); + auto C1 = peekChar(0); + auto C2 = peekChar(1); + if (C1 == '!' && C2 == '!') { + getChar(); + getChar(); + auto Name = scanIdentifier(); + std::string Value; + for (;;) { + C0 = getChar(); + Value.push_back(C0); + if (C0 == '\n' || C0 == EOF) { + break; + } + } + continue; + } for (;;) { C0 = getChar(); if (C0 == '\n' || C0 == EOF) { @@ -278,7 +321,8 @@ digit_finish: case '\'': Text.push_back('\''); break; case '"': Text.push_back('"'); break; default: - throw UnexpectedStringDiagnostic(File, Loc, String { static_cast(C1) }); + DE.add(File, Loc, String { static_cast(C1) }); + return nullptr; } Escaping = false; } else { @@ -305,7 +349,8 @@ after_string_contents: getChar(); auto C2 = peekChar(); if (C2 == '.') { - throw UnexpectedStringDiagnostic(File, getCurrentLoc(), String { static_cast(C2) }); + DE.add(File, getCurrentLoc(), String { static_cast(C2) }); + return nullptr; } return new DotDot(StartLoc); } @@ -347,7 +392,6 @@ after_string_contents: } return new CustomOperator(Text, StartLoc); } - #define BOLT_SIMPLE_TOKEN(ch, name) case ch: return new name(StartLoc); @@ -360,14 +404,26 @@ after_string_contents: BOLT_SIMPLE_TOKEN('{', LBrace) BOLT_SIMPLE_TOKEN('}', RBrace) BOLT_SIMPLE_TOKEN('~', Tilde) + BOLT_SIMPLE_TOKEN('@', At) default: - throw UnexpectedStringDiagnostic(File, StartLoc, String { static_cast(C0) }); + DE.add(File, StartLoc, String { static_cast(C0) }); + return nullptr; } } + Token* Scanner::read() { + for (;;) { + auto T0 = readNullable(); + if (T0) { + // EndOFFile is guaranteed to be produced, so that ends the stream. + return T0; + } + } + } + Punctuator::Punctuator(Stream& Tokens): Tokens(Tokens) { Frames.push(FrameType::Block); diff --git a/src/Types.cc b/src/Types.cc index 9e2c19e75..a7d2ac52b 100644 --- a/src/Types.cc +++ b/src/Types.cc @@ -306,118 +306,99 @@ namespace bolt { ZEN_UNREACHABLE } - // bool Type::operator==(const Type& Other) const noexcept { - // switch (Kind) { - // case TypeKind::Var: - // if (Other.Kind != TypeKind::Var) { - // return false; - // } - // return static_cast(this)->Id == static_cast(Other).Id; - // case TypeKind::Tuple: - // { - // if (Other.Kind != TypeKind::Tuple) { - // return false; - // } - // auto A = static_cast(*this); - // auto B = static_cast(Other); - // if (A.ElementTypes.size() != B.ElementTypes.size()) { - // return false; - // } - // for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) { - // if (*T1 != *T2) { - // return false; - // } - // } - // return true; - // } - // case TypeKind::TupleIndex: - // { - // if (Other.Kind != TypeKind::TupleIndex) { - // return false; - // } - // auto A = static_cast(*this); - // auto B = static_cast(Other); - // return A.I == B.I && *A.Ty == *B.Ty; - // } - // case TypeKind::Con: - // { - // if (Other.Kind != TypeKind::Con) { - // return false; - // } - // auto A = static_cast(*this); - // auto B = static_cast(Other); - // if (A.Id != B.Id) { - // return false; - // } - // if (A.Args.size() != B.Args.size()) { - // return false; - // } - // for (auto [T1, T2]: zen::zip(A.Args, B.Args)) { - // if (*T1 != *T2) { - // return false; - // } - // } - // return true; - // } - // case TypeKind::Arrow: - // { - // if (Other.Kind != TypeKind::Arrow) { - // return false; - // } - // auto A = static_cast(*this); - // auto B = static_cast(Other); - // /* ArrowCursor C1 { &A }; */ - // /* ArrowCursor C2 { &B }; */ - // /* for (;;) { */ - // /* auto T1 = C1.next(); */ - // /* auto T2 = C2.next(); */ - // /* if (T1 == nullptr && T2 == nullptr) { */ - // /* break; */ - // /* } */ - // /* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */ - // /* return false; */ - // /* } */ - // /* } */ - // if (A.ParamTypes.size() != B.ParamTypes.size()) { - // return false; - // } - // for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) { - // if (*T1 != *T2) { - // return false; - // } - // } - // return A.ReturnType != B.ReturnType; - // } - // case TypeKind::Absent: - // if (Other.Kind != TypeKind::Absent) { - // return false; - // } - // return true; - // case TypeKind::Nil: - // if (Other.Kind != TypeKind::Nil) { - // return false; - // } - // return true; - // case TypeKind::Present: - // { - // if (Other.Kind != TypeKind::Present) { - // return false; - // } - // auto A = static_cast(*this); - // auto B = static_cast(Other); - // return *A.Ty == *B.Ty; - // } - // case TypeKind::Field: - // { - // if (Other.Kind != TypeKind::Field) { - // return false; - // } - // auto A = static_cast(*this); - // auto B = static_cast(Other); - // return *A.Ty == *B.Ty && *A.RestTy == *B.RestTy; - // } - // } - // } + bool Type::operator==(const Type& Other) const noexcept { + switch (Kind) { + case TypeKind::Var: + if (Other.Kind != TypeKind::Var) { + return false; + } + return static_cast(this)->Id == static_cast(Other).Id; + case TypeKind::Tuple: + { + if (Other.Kind != TypeKind::Tuple) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + if (A.ElementTypes.size() != B.ElementTypes.size()) { + return false; + } + for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) { + if (*T1 != *T2) { + return false; + } + } + return true; + } + case TypeKind::TupleIndex: + { + if (Other.Kind != TypeKind::TupleIndex) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + return A.I == B.I && *A.Ty == *B.Ty; + } + case TypeKind::Con: + { + if (Other.Kind != TypeKind::Con) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + if (A.Id != B.Id) { + return false; + } + return true; + } + case TypeKind::App: + { + if (Other.Kind != TypeKind::App) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + return *A.Op == *B.Op && *A.Arg == *B.Arg; + } + case TypeKind::Arrow: + { + if (Other.Kind != TypeKind::Arrow) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + return *A.ParamType == *B.ParamType && *A.ReturnType == *B.ReturnType; + } + case TypeKind::Absent: + if (Other.Kind != TypeKind::Absent) { + return false; + } + return true; + case TypeKind::Nil: + if (Other.Kind != TypeKind::Nil) { + return false; + } + return true; + case TypeKind::Present: + { + if (Other.Kind != TypeKind::Present) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + return *A.Ty == *B.Ty; + } + case TypeKind::Field: + { + if (Other.Kind != TypeKind::Field) { + return false; + } + auto A = static_cast(*this); + auto B = static_cast(Other); + return A.Name == B.Name && *A.Ty == *B.Ty && *A.RestTy == *B.RestTy; + } + } + } TypeIterator Type::begin() { return TypeIterator { this, getStartIndex() }; diff --git a/src/main.cc b/src/main.cc index b63f2a5ed..812899caf 100644 --- a/src/main.cc +++ b/src/main.cc @@ -4,11 +4,13 @@ #include #include #include +#include #include "zen/config.hpp" #include "zen/po.hpp" #include "bolt/CST.hpp" +#include "bolt/CSTVisitor.hpp" #include "bolt/DiagnosticEngine.hpp" #include "bolt/Diagnostics.hpp" #include "bolt/Scanner.hpp" @@ -38,10 +40,14 @@ namespace po = zen::po; int main(int Argc, const char* Argv[]) { auto Match = po::program("bolt", "The offical compiler for the Bolt programming language") + .flag(po::flag("additional-syntax", "Enable additional Bolt syntax for asserting compiler state")) .flag(po::flag("direct-diagnostics", "Immediately print diagnostics without sorting them first")) // TODO support default values in zen::po .subcommand( po::command("check", "Check sources for programming mistakes") .pos_arg("file", po::some)) + .subcommand( + po::command("verify", "Verify integrity of the compiler on selected file(s)") + .pos_arg("file", po::some)) .subcommand( po::command("eval", "Run sources") .pos_arg("file", po::some) @@ -51,10 +57,12 @@ int main(int Argc, const char* Argv[]) { ZEN_ASSERT(Match.has_subcommand()); - auto DirectDiagnostics = Match.has_flag("direct-diagnostics") && Match.get_flag("direct-diagnostics"); - auto [Name, Submatch] = Match.subcommand(); + auto IsVerify = Name == "verify"; + auto DirectDiagnostics = Match.has_flag("direct-diagnostics") && Match.get_flag("direct-diagnostics") && !IsVerify; + auto AdditionalSyntax = Match.has_flag("additional-syntax") && Match.get_flag("additional-syntax"); + ConsoleDiagnostics DE; LanguageConfig Config; @@ -65,7 +73,7 @@ int main(int Argc, const char* Argv[]) { auto Text = readFile(Filename); TextFile File { Filename, Text }; VectorStream Chars(Text, EOF); - Scanner S(File, Chars); + Scanner S(DE, File, Chars); Punctuator PT(S); Parser P(File, PT, DE); @@ -86,28 +94,78 @@ int main(int Argc, const char* Argv[]) { TheChecker.check(SF); } - auto lessThan = [](const Diagnostic* L, const Diagnostic* R) { - auto N1 = L->getNode(); - auto N2 = R->getNode(); - if (N1 == nullptr && N2 == nullptr) { - return false; - } - if (N1 == nullptr) { - return true; - } - if (N2 == nullptr) { - return false; - } - return N1->getStartLine() < N2->getStartLine() || N1->getStartColumn() < N2->getStartColumn(); - }; - std::sort(DS.Diagnostics.begin(), DS.Diagnostics.end(), lessThan); + if (IsVerify) { - for (auto D: DS.Diagnostics) { - DE.addDiagnostic(D); - } + struct Visitor : public CSTVisitor { + Checker& C; + DiagnosticEngine& DE; + void visitExpression(Expression* N) { + for (auto A: N->Annotations) { + if (A->getKind() == NodeKind::TypeAssertAnnotation) { + auto Left = C.getType(N); + auto Right = static_cast(A)->getTypeExpression()->getType(); + std::cerr << "verify " << describe(Left) << " == " << describe(Right) << std::endl; + if (*Left != *Right) { + DE.add(Left, Right, TypePath(), TypePath(), A); + } + } + } + visitEachChild(N); + } + }; + + Visitor V { {}, TheChecker, DE }; + for (auto SF: SourceFiles) { + V.visit(SF); + } + + struct EDVisitor : public CSTVisitor { + std::multimap Expected; + void visitExpressionAnnotation(ExpressionAnnotation* N) { + if (N->getExpression()->is()) { + auto CE = static_cast(N->getExpression()); + if (CE->Function->is()) { + auto RE = static_cast(CE->Function); + if (RE->getNameAsString() == "expect_diagnostic") { + ZEN_ASSERT(CE->Args.size() == 1 && CE->Args[0]->is()); + Expected.emplace(N->Parent->getStartLine(), static_cast(CE->Args[0])->getAsInt()); + } + } + } + } + }; + + EDVisitor V1; + for (auto SF: SourceFiles) { + V1.visit(SF); + } + + for (auto D: DS.Diagnostics) { + auto N = D->getNode(); + if (!N) { + DE.addDiagnostic(D); + } else { + auto Line = N->getStartLine(); + auto Match = V1.Expected.find(Line); + if (Match != V1.Expected.end() && Match->second == D->getCode()) { + std::cerr << "skipped 1 diagnostic" << std::endl; + } else { + DE.addDiagnostic(D); + } + } + } + + } else { + + DS.sort(); + for (auto D: DS.Diagnostics) { + DE.addDiagnostic(D); + } + + if (DE.hasError()) { + return 1; + } - if (DE.hasError()) { - return 1; } if (Name == "eval") { diff --git a/test/checker/local_constraints_polymorphic_variable.bolt b/test/checker/local_constraints_polymorphic_variable.bolt new file mode 100644 index 000000000..92c70c837 --- /dev/null +++ b/test/checker/local_constraints_polymorphic_variable.bolt @@ -0,0 +1,8 @@ + +let fac n. + n + 1 + return n + +@expect_diagnostic 2010 +(@:Int fac 1) + (@:Bool True) +