diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index c7f0e2116..45efafcf9 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -29,6 +29,9 @@ namespace bolt { ReturnKeyword, ModKeyword, StructKeyword, + ElifKeyword, + IfKeyword, + ElseKeyword, Invalid, EndOfFile, BlockStart, @@ -50,6 +53,8 @@ namespace bolt { UnaryExpression, ExpressionStatement, ReturnStatement, + IfStatement, + IfStatementPart, TypeAssert, Param, LetBlockBody, @@ -338,6 +343,42 @@ namespace bolt { }; + class ElseKeyword : public Token { + public: + + ElseKeyword(TextLoc StartLoc): + Token(NodeType::ElseKeyword, StartLoc) {} + + std::string getText() const override; + + ~ElseKeyword(); + + }; + + class ElifKeyword : public Token { + public: + + ElifKeyword(TextLoc StartLoc): + Token(NodeType::ElifKeyword, StartLoc) {} + + std::string getText() const override; + + ~ElifKeyword(); + + }; + + class IfKeyword : public Token { + public: + + IfKeyword(TextLoc StartLoc): + Token(NodeType::IfKeyword, StartLoc) {} + + std::string getText() const override; + + ~IfKeyword(); + + }; + class ModKeyword : public Token { public: @@ -731,6 +772,51 @@ namespace bolt { }; + class IfStatementPart : public Node { + public: + + Token* Keyword; + Expression* Test; + BlockStart* BlockStart; + std::vector Elements; + + inline IfStatementPart( + Token* Keyword, + Expression* Test, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeType::IfStatementPart), + Keyword(Keyword), + Test(Test), + BlockStart(BlockStart), + Elements(Elements) {} + + void setParents() override; + + Token* getFirstToken() override; + Token* getLastToken() override; + + ~IfStatementPart(); + + }; + + class IfStatement : public Statement { + public: + + std::vector Parts; + + inline IfStatement(std::vector Parts): + Statement(NodeType::IfStatement), Parts(Parts) {} + + void setParents() override; + + Token* getFirstToken() override; + Token* getLastToken() override; + + ~IfStatement(); + + }; + class ReturnStatement : public Statement { public: diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 747f732a3..0cb90feaf 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -1,11 +1,11 @@ #pragma once -#include "bolt/Diagnostics.hpp" #include "zen/config.hpp" #include "bolt/ByteString.hpp" +#include #include #include #include @@ -13,8 +13,11 @@ namespace bolt { + class DiagnosticEngine; class Node; class Expression; + class TypeExpression; + class Pattern; class SourceFile; class Type; @@ -28,6 +31,7 @@ namespace bolt { Con, Arrow, Any, + Tuple, }; class Type { @@ -88,6 +92,16 @@ namespace bolt { }; + class TTuple : public Type { + public: + + std::vector ElementTypes; + + inline TTuple(std::vector ElementTypes): + Type(TypeKind::Tuple), ElementTypes(ElementTypes) {} + + }; + class TAny : public Type { public: @@ -115,7 +129,7 @@ namespace bolt { Type* Type; inline Forall(class Type* Type): - TVs(nullptr), Constraints(nullptr), Type(Type) {} + TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {} inline Forall( TVSet& TVs, @@ -228,10 +242,10 @@ namespace bolt { class CMany : public Constraint { public: - ConstraintSet& Constraints; + ConstraintSet& Elements; inline CMany(ConstraintSet& Constraints): - Constraint(ConstraintKind::Many), Constraints(Constraints) {} + Constraint(ConstraintKind::Many), Elements(Constraints) {} }; @@ -249,14 +263,15 @@ namespace bolt { TVSet TVs; ConstraintSet Constraints; TypeEnv Env; + Type* ReturnType; InferContext* Parent; - inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env): - Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env) {} + inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType): + Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {} inline InferContext(InferContext* Parent = nullptr): - Parent(Parent) {} + Parent(Parent), ReturnType(nullptr) {} void addConstraint(Constraint* C); @@ -275,6 +290,14 @@ namespace bolt { size_t nextConTypeId = 0; size_t nextTypeVarId = 0; + Type* BoolType; + Type* IntType; + Type* StringType; + + std::stack Contexts; + + void addConstraint(Constraint* Constraint); + Type* inferExpression(Expression* Expression, InferContext& Ctx); Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx); diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index 177eb3372..50a0ecc75 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -68,6 +68,8 @@ namespace bolt { Token* peekFirstTokenAfterModifiers(); + Token* expectToken(NodeType Ty); + Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence); TypeExpression* parsePrimitiveTypeExpression(); @@ -94,6 +96,10 @@ namespace bolt { Expression* parseCallExpression(); + IfStatement* parseIfStatement(); + + ReturnStatement* parseReturnStatement(); + ExpressionStatement* parseExpressionStatement(); Node* parseLetBodyElement(); diff --git a/src/CST.cc b/src/CST.cc index a5c4d7526..6be1d18e8 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -99,6 +99,26 @@ namespace bolt { Expression->setParents(); } + void IfStatementPart::setParents() { + Keyword->Parent = this; + if (Test) { + Test->Parent = this; + Test->setParents(); + } + BlockStart->Parent = this; + for (auto Element: Elements) { + Element->Parent = this; + Element->setParents(); + } + } + + void IfStatement::setParents() { + for (auto Part: Parts) { + Part->Parent = this; + Part->setParents(); + } + } + void TypeAssert::setParents() { Colon->Parent = this; TypeExpression->Parent = this; @@ -230,6 +250,15 @@ namespace bolt { ReturnKeyword::~ReturnKeyword() { } + IfKeyword::~IfKeyword() { + } + + ElifKeyword::~ElifKeyword() { + } + + ElseKeyword::~ElseKeyword() { + } + ModKeyword::~ModKeyword() { } @@ -335,6 +364,23 @@ namespace bolt { Expression->unref(); } + IfStatementPart::~IfStatementPart() { + Keyword->unref(); + if (Test) { + Test->unref(); + } + BlockStart->unref(); + for (auto Element: Elements) { + Element->unref(); + } + } + + IfStatement::~IfStatement() { + for (auto Part: Parts) { + Part->unref(); + } + } + TypeAssert::~TypeAssert() { Colon->unref(); TypeExpression->unref(); @@ -501,6 +547,27 @@ namespace bolt { return ReturnKeyword; } + Token* IfStatementPart::getFirstToken() { + return Keyword; + } + + Token* IfStatementPart::getLastToken() { + if (Elements.size()) { + return Elements.back()->getLastToken(); + } + return BlockStart; + } + + Token* IfStatement::getFirstToken() { + ZEN_ASSERT(Parts.size()); + return Parts.front()->getFirstToken(); + } + + Token* IfStatement::getLastToken() { + ZEN_ASSERT(Parts.size()); + return Parts.back()->getLastToken(); + } + Token* TypeAssert::getFirstToken() { return Colon; } @@ -655,6 +722,18 @@ namespace bolt { return "return"; } + std::string IfKeyword::getText() const { + return "if"; + } + + std::string ElseKeyword::getText() const { + return "else"; + } + + std::string ElifKeyword::getText() const { + return "elif"; + } + std::string ModKeyword::getText() const { return "mod"; } diff --git a/src/Checker.cc b/src/Checker.cc index 872523cdc..daa219235 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -35,6 +35,16 @@ namespace bolt { } return false; } + case TypeKind::Tuple: + { + auto Y = static_cast(this); + for (auto Ty: Y->ElementTypes) { + if (Ty->hasTypeVar(TV)) { + return true; + } + } + return false; + } case TypeKind::Any: return false; } @@ -69,6 +79,15 @@ namespace bolt { } return new TCon(Y->Id, NewArgs, Y->DisplayName); } + case TypeKind::Tuple: + { + auto Y = static_cast(this); + std::vector NewElementTypes; + for (auto Ty: Y->ElementTypes) { + NewElementTypes.push_back(Ty->substitute(Sub)); + } + return new TTuple(NewElementTypes); + } } } @@ -126,7 +145,11 @@ namespace bolt { } Checker::Checker(DiagnosticEngine& DE): - DE(DE) {} + DE(DE) { + BoolType = new TCon(nextConTypeId++, {}, "Bool"); + IntType = new TCon(nextConTypeId++, {}, "Int"); + StringType = new TCon(nextConTypeId++, {}, "String"); + } void Checker::infer(Node* X, InferContext& Ctx) { @@ -141,6 +164,20 @@ namespace bolt { break; } + case NodeType::IfStatement: + { + auto Y = static_cast(X); + for (auto Part: Y->Parts) { + if (Part->Test != nullptr) { + Ctx.addConstraint(new CEqual { BoolType, inferExpression(Part->Test, Ctx), Part->Test }); + } + for (auto Element: Part->Elements) { + infer(Element, Ctx); + } + } + break; + } + case NodeType::LetDeclaration: { auto Y = static_cast(X); @@ -178,6 +215,7 @@ namespace bolt { { auto Z = static_cast(Y->Body); RetType = createTypeVar(*NewCtx); + NewCtx->ReturnType = RetType; for (auto Element: Z->Elements) { infer(Element, *NewCtx); } @@ -197,6 +235,19 @@ namespace bolt { break; } + case NodeType::ReturnStatement: + { + auto Y = static_cast(X); + Type* ReturnType; + if (Y->Expression) { + ReturnType = inferExpression(Y->Expression, Ctx); + } else { + ReturnType = new TTuple({}); + } + ZEN_ASSERT(Ctx.ReturnType != nullptr); + Ctx.addConstraint(new CEqual { ReturnType, Ctx.ReturnType, X }); + break; + } case NodeType::ExpressionStatement: { @@ -375,16 +426,15 @@ namespace bolt { void Checker::check(SourceFile *SF) { InferContext Toplevel; - auto StringTy = new TCon(nextConTypeId++, {}, "String"); - auto IntTy = new TCon(nextConTypeId++, {}, "Int"); - auto BoolTy = new TCon(nextConTypeId++, {}, "Bool"); - Toplevel.addBinding("String", Forall(StringTy)); - Toplevel.addBinding("Int", Forall(IntTy)); - Toplevel.addBinding("Bool", Forall(BoolTy)); - Toplevel.addBinding("+", Forall(new TArrow({ IntTy, IntTy }, IntTy))); - Toplevel.addBinding("-", Forall(new TArrow({ IntTy, IntTy }, IntTy))); - Toplevel.addBinding("*", Forall(new TArrow({ IntTy, IntTy }, IntTy))); - Toplevel.addBinding("/", Forall(new TArrow({ IntTy, IntTy }, IntTy))); + Toplevel.addBinding("String", Forall(StringType)); + Toplevel.addBinding("Int", Forall(IntType)); + Toplevel.addBinding("Bool", Forall(BoolType)); + Toplevel.addBinding("True", Forall(BoolType)); + Toplevel.addBinding("False", Forall(BoolType)); + Toplevel.addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType))); + Toplevel.addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType))); + Toplevel.addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType))); + Toplevel.addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType))); infer(SF, Toplevel); solve(new CMany(Toplevel.Constraints)); } @@ -480,6 +530,22 @@ namespace bolt { return unify(Y->ReturnType, Z->ReturnType, Solution); } + if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) { + auto Y = static_cast(A); + auto Z = static_cast(B); + if (Y->ElementTypes.size() != Z->ElementTypes.size()) { + return false; + } + auto Count = Y->ElementTypes.size(); + bool Success = true; + for (size_t I = 0; I < Count; I++) { + if (!unify(Y->ElementTypes[I], Z->ElementTypes[I], Solution)) { + Success = false; + } + } + return Success; + } + if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) { auto Y = static_cast(A); auto Z = static_cast(B); diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 5713cfbd3..32edbca61 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -129,6 +129,21 @@ namespace bolt { } return Out.str(); } + case TypeKind::Tuple: + { + std::ostringstream Out; + auto Y = static_cast(Ty); + Out << "("; + if (Y->ElementTypes.size()) { + auto Iter = Y->ElementTypes.begin(); + Out << describe(*Iter++); + while (Iter != Y->ElementTypes.end()) { + Out << ", " << describe(*Iter++); + } + } + Out << ")"; + return Out.str(); + } } } diff --git a/src/Parser.cc b/src/Parser.cc index 093929544..702549ed7 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -3,6 +3,7 @@ #include "bolt/Scanner.hpp" #include "bolt/Parser.hpp" #include "bolt/Diagnostics.hpp" +#include #include namespace bolt { @@ -74,6 +75,14 @@ namespace bolt { } \ } + Token* Parser::expectToken(NodeType Type) { + auto T = Tokens.get(); + if (T->Type != Type) { + throw UnexpectedTokenDiagnostic(File, T, std::vector { Type }); \ + } + return T; + } + Pattern* Parser::parsePattern() { auto T0 = Tokens.peek(); switch (T0->Type) { @@ -87,10 +96,7 @@ namespace bolt { QualifiedName* Parser::parseQualifiedName() { std::vector ModulePath; - auto Name = Tokens.get(); - if (Name->Type != NodeType::Identifier) { - throw UnexpectedTokenDiagnostic(File, Name, std::vector { NodeType::Identifier }); - } + auto Name = expectToken(NodeType::Identifier); for (;;) { auto T1 = Tokens.peek(); if (T1->Type != NodeType::Dot) { @@ -156,7 +162,7 @@ namespace bolt { std::vector Args; for (;;) { auto T1 = Tokens.peek(); - if (T1->Type == NodeType::LineFoldEnd || ExprOperators.isInfix(T1)) { + if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) { break; } Args.push_back(parsePrimitiveExpression()); @@ -216,6 +222,52 @@ namespace bolt { return new ExpressionStatement(E); } + ReturnStatement* Parser::parseReturnStatement() { + auto T0 = static_cast(expectToken(NodeType::ReturnKeyword)); + Expression* Expression = nullptr; + auto T1 = Tokens.peek(); + if (T1->Type != NodeType::LineFoldEnd) { + Expression = parseExpression(); + } + BOLT_EXPECT_TOKEN(LineFoldEnd); + return new ReturnStatement(static_cast(T0), Expression); + } + + IfStatement* Parser::parseIfStatement() { + std::vector Parts; + auto T0 = expectToken(NodeType::IfKeyword); + auto Test = parseExpression(); + auto T1 = static_cast(expectToken(NodeType::BlockStart)); + std::vector Then; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->Type == NodeType::BlockEnd) { + Tokens.get(); + break; + } + Then.push_back(parseLetBodyElement()); + } + Parts.push_back(new IfStatementPart(T0, Test, T1, Then)); + BOLT_EXPECT_TOKEN(LineFoldEnd) + auto T3 = Tokens.peek(); + if (T3->Type == NodeType::ElseKeyword) { + Tokens.get(); + auto T4 = static_cast(expectToken(NodeType::BlockStart)); + std::vector Else; + for (;;) { + auto T5 = Tokens.peek(); + if (T5->Type == NodeType::BlockEnd) { + Tokens.get(); + break; + } + Else.push_back(parseLetBodyElement()); + } + Parts.push_back(new IfStatementPart(T3, nullptr, T4, Else)); + BOLT_EXPECT_TOKEN(LineFoldEnd) + } + return new IfStatement(Parts); + } + LetDeclaration* Parser::parseLetDeclaration() { PubKeyword* Pub = nullptr; @@ -316,6 +368,10 @@ after_params: switch (T0->Type) { case NodeType::LetKeyword: return parseLetDeclaration(); + case NodeType::ReturnKeyword: + return parseReturnStatement(); + case NodeType::IfKeyword: + return parseIfStatement(); default: return parseExpressionStatement(); } @@ -326,6 +382,8 @@ after_params: switch (T0->Type) { case NodeType::LetKeyword: return parseLetDeclaration(); + case NodeType::IfKeyword: + return parseIfStatement(); default: return parseExpressionStatement(); } diff --git a/src/Scanner.cc b/src/Scanner.cc index 81c0c333d..7943e71e5 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -64,6 +64,9 @@ namespace bolt { { "return", NodeType::ReturnKeyword }, { "type", NodeType::TypeKeyword }, { "mod", NodeType::ModKeyword }, + { "if", NodeType::IfKeyword }, + { "else", NodeType::ElseKeyword }, + { "elif", NodeType::ElifKeyword }, }; Scanner::Scanner(TextFile& File, Stream& Chars): @@ -209,6 +212,12 @@ digit_finish: return new TypeKeyword(StartLoc); case NodeType::ReturnKeyword: return new ReturnKeyword(StartLoc); + case NodeType::IfKeyword: + return new IfKeyword(StartLoc); + case NodeType::ElifKeyword: + return new ElifKeyword(StartLoc); + case NodeType::ElseKeyword: + return new ElseKeyword(StartLoc); default: ZEN_UNREACHABLE }