From b4d54f025c039df0871ace1648eee0e7bb7da82d Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Thu, 25 Aug 2022 19:04:25 +0200 Subject: [PATCH] Improve type inference and some minor updates --- include/bolt/CST.hpp | 36 ++++++ include/bolt/Checker.hpp | 72 ++++++----- include/bolt/Parser.hpp | 6 +- src/CST.cc | 34 +++++ src/Checker.cc | 265 +++++++++++++++++++++++++++++++-------- src/Diagnostics.cc | 13 +- src/Parser.cc | 20 ++- src/Scanner.cc | 4 +- 8 files changed, 357 insertions(+), 93 deletions(-) diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index f590515b4..c7f0e2116 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -21,6 +21,7 @@ namespace bolt { RBracket, LBrace, RBrace, + RArrow, LetKeyword, MutKeyword, PubKeyword, @@ -40,6 +41,7 @@ namespace bolt { IntegerLiteral, QualifiedName, ReferenceTypeExpression, + ArrowTypeExpression, BindPattern, ReferenceExpression, ConstantExpression, @@ -168,6 +170,18 @@ namespace bolt { }; + class RArrow : public Token { + public: + + RArrow(TextLoc StartLoc): + Token(NodeType::RArrow, StartLoc) {} + + std::string getText() const override; + + ~RArrow(); + + }; + class Dot : public Token { public: @@ -528,6 +542,28 @@ namespace bolt { }; + class ArrowTypeExpression : public TypeExpression { + public: + + std::vector ParamTypes; + TypeExpression* ReturnType; + + inline ArrowTypeExpression( + std::vector ParamTypes, + TypeExpression* ReturnType + ): TypeExpression(NodeType::ArrowTypeExpression), + ParamTypes(ParamTypes), + ReturnType(ReturnType) {} + + void setParents() override; + + Token* getFirstToken() override; + Token* getLastToken() override; + + ~ArrowTypeExpression(); + + }; + class Pattern : public Node { public: diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index dcced3829..747f732a3 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -105,22 +105,24 @@ namespace bolt { class Constraint; + using ConstraintSet = std::vector; + class Forall { public: - TVSet TVs; - std::vector Constraints; + TVSet* TVs; + ConstraintSet* Constraints; Type* Type; inline Forall(class Type* Type): - Type(Type) {} + TVs(nullptr), Constraints(nullptr), Type(Type) {} inline Forall( - TVSet TVs, - std::vector Constraints, + TVSet& TVs, + ConstraintSet& Constraints, class Type* Type - ): TVs(TVs), - Constraints(Constraints), + ): TVs(&TVs), + Constraints(&Constraints), Type(Type) {} }; @@ -184,19 +186,7 @@ namespace bolt { }; - class TypeEnv { - - std::unordered_map Mapping; - - public: - - void add(ByteString Name, Scheme S); - - Scheme* lookup(ByteString Name); - - Type* lookupMono(ByteString Name); - - }; + using TypeEnv = std::unordered_map; enum class ConstraintKind { Equal, @@ -217,12 +207,12 @@ namespace bolt { return Kind; } + Constraint* substitute(const TVSub& Sub); + virtual ~Constraint() {} }; - using ConstraintSet = std::vector; - class CEqual : public Constraint { public: @@ -238,9 +228,9 @@ namespace bolt { class CMany : public Constraint { public: - ConstraintSet Constraints; + ConstraintSet& Constraints; - inline CMany(ConstraintSet Constraints): + inline CMany(ConstraintSet& Constraints): Constraint(ConstraintKind::Many), Constraints(Constraints) {} }; @@ -254,18 +244,28 @@ namespace bolt { }; class InferContext { - - ConstraintSet& Constraints; - public: - TypeEnv& Env; + TVSet TVs; + ConstraintSet Constraints; + TypeEnv Env; - inline InferContext(ConstraintSet& Constraints, TypeEnv& Env): - Constraints(Constraints), Env(Env) {} + InferContext* Parent; + + inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env): + Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env) {} + + inline InferContext(InferContext* Parent = nullptr): + Parent(Parent) {} void addConstraint(Constraint* C); + void addBinding(ByteString Name, Scheme Scm); + + Type* lookupMono(ByteString Name); + + Scheme* lookup(ByteString Name); + }; class Checker { @@ -275,14 +275,18 @@ namespace bolt { size_t nextConTypeId = 0; size_t nextTypeVarId = 0; - Type* inferExpression(Expression* Expression, InferContext& Env); + Type* inferExpression(Expression* Expression, InferContext& Ctx); + Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx); - void infer(Node* node, InferContext& Env); + void inferBindings(Pattern* Pattern, Type* T, InferContext& Ctx, ConstraintSet& Constraints, TVSet& Tvs); + + void infer(Node* node, InferContext& Ctx); TCon* createPrimConType(); - TVar* createTypeVar(); - Type* instantiate(Scheme& S); + TVar* createTypeVar(InferContext& Ctx); + + Type* instantiate(Scheme& S, InferContext& Ctx, Node* Source); bool unify(Type* A, Type* B, TVSub& Solution); diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index 1cbaa1752..177eb3372 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -70,6 +70,10 @@ namespace bolt { Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence); + TypeExpression* parsePrimitiveTypeExpression(); + + Expression* parsePrimitiveExpression(); + public: Parser(TextFile& File, Stream& S); @@ -86,8 +90,6 @@ namespace bolt { Expression* parseUnaryExpression(); - Expression* parsePrimitiveExpression(); - Expression* parseExpression(); Expression* parseCallExpression(); diff --git a/src/CST.cc b/src/CST.cc index 66d2c76f9..a5c4d7526 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -43,6 +43,15 @@ namespace bolt { Name->Parent = this; Name->setParents(); } + + void ArrowTypeExpression::setParents() { + for (auto ParamType: ParamTypes) { + ParamType->Parent = this; + ParamType->setParents(); + } + ReturnType->Parent = this; + ReturnType->setParents(); + } void BindPattern::setParents() { Name->Parent = this; @@ -179,6 +188,9 @@ namespace bolt { Colon::~Colon() { } + RArrow::~RArrow() { + } + Dot::~Dot() { } @@ -268,6 +280,13 @@ namespace bolt { Name->unref(); } + ArrowTypeExpression::~ArrowTypeExpression() { + for (auto ParamType: ParamTypes) { + ParamType->unref(); + } + ReturnType->unref(); + } + Pattern::~Pattern() { } @@ -401,6 +420,17 @@ namespace bolt { return Name->getFirstToken(); } + Token* ArrowTypeExpression::getFirstToken() { + if (ParamTypes.size()) { + return ParamTypes.front()->getFirstToken(); + } + return ReturnType->getFirstToken(); + } + + Token* ArrowTypeExpression::getLastToken() { + return ReturnType->getLastToken(); + } + Token* BindPattern::getFirstToken() { return Name; } @@ -573,6 +603,10 @@ namespace bolt { return ":"; } + std::string RArrow::getText() const { + return "->"; + } + std::string Dot::getText() const { return "."; } diff --git a/src/Checker.cc b/src/Checker.cc index 850ff637f..872523cdc 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -9,27 +9,7 @@ namespace bolt { - Scheme* TypeEnv::lookup(ByteString Name) { - auto Match = Mapping.find(Name); - if (Match == Mapping.end()) { - return {}; - } - return &Match->second; - } - - Type* TypeEnv::lookupMono(ByteString Name) { - auto Match = Mapping.find(Name); - if (Match == Mapping.end()) { - return nullptr; - } - auto& F = Match->second.as(); - ZEN_ASSERT(F.TVs.empty()); - return F.Type; - } - - void TypeEnv::add(ByteString Name, Scheme S) { - Mapping.emplace(Name, S); - } + std::string describe(const Type* Ty); bool Type::hasTypeVar(const TVar* TV) { switch (Kind) { @@ -66,7 +46,7 @@ namespace bolt { { auto Y = static_cast(this); auto Match = Sub.find(Y); - return Match != Sub.end() ? Match->second : Y; + return Match != Sub.end() ? Match->second->substitute(Sub) : Y; } case TypeKind::Arrow: { @@ -87,11 +67,60 @@ namespace bolt { for (auto Arg: Y->Args) { NewArgs.push_back(Arg->substitute(Sub)); } - return new TCon(Y->Id, Y->Args, Y->DisplayName); + return new TCon(Y->Id, NewArgs, Y->DisplayName); } } } + Constraint* Constraint::substitute(const TVSub &Sub) { + switch (Kind) { + case ConstraintKind::Equal: + { + auto Y = static_cast(this); + return new CEqual(Y->Left->substitute(Sub), Y->Right->substitute(Sub), Y->Source); + } + case ConstraintKind::Many: + { + auto Y = static_cast(this); + auto NewConstraints = new ConstraintSet(); + for (auto Element: Y->Constraints) { + NewConstraints->push_back(Element->substitute(Sub)); + } + return new CMany(*NewConstraints); + } + case ConstraintKind::Empty: + return this; + } + } + + Scheme* InferContext::lookup(ByteString Name) { + InferContext* Curr = this; + for (;;) { + auto Match = Curr->Env.find(Name); + if (Match != Curr->Env.end()) { + return &Match->second; + } + Curr = Curr->Parent; + if (Curr == nullptr) { + return nullptr; + } + } + } + + Type* InferContext::lookupMono(ByteString Name) { + auto Scm = lookup(Name); + if (Scm == nullptr) { + return nullptr; + } + auto& F = Scm->as(); + ZEN_ASSERT(F.TVs == nullptr || F.TVs->empty()); + return F.Type; + } + + void InferContext::addBinding(ByteString Name, Scheme S) { + Env.emplace(Name, S); + } + void InferContext::addConstraint(Constraint *C) { Constraints.push_back(C); } @@ -114,7 +143,57 @@ namespace bolt { case NodeType::LetDeclaration: { - // TODO + auto Y = static_cast(X); + + auto NewCtx = new InferContext { Ctx }; + + Type* Ty; + if (Y->TypeAssert) { + Ty = inferTypeExpression(Y->TypeAssert->TypeExpression, *NewCtx); + } else { + Ty = createTypeVar(*NewCtx); + } + + std::vector ParamTypes; + Type* RetType; + + for (auto Param: Y->Params) { + // TODO incorporate Param->TypeAssert or make it a kind of pattern + TVar* TV = createTypeVar(*NewCtx); + TVSet NoTVs; + ConstraintSet NoConstraints; + inferBindings(Param->Pattern, TV, *NewCtx, NoConstraints, NoTVs); + ParamTypes.push_back(TV); + } + + if (Y->Body) { + switch (Y->Body->Type) { + case NodeType::LetExprBody: + { + auto Z = static_cast(Y->Body); + RetType = inferExpression(Z->Expression, *NewCtx); + break; + } + case NodeType::LetBlockBody: + { + auto Z = static_cast(Y->Body); + RetType = createTypeVar(*NewCtx); + for (auto Element: Z->Elements) { + infer(Element, *NewCtx); + } + break; + } + default: + ZEN_UNREACHABLE + } + } else { + RetType = createTypeVar(*NewCtx); + } + + NewCtx->addConstraint(new CEqual { Ty, new TArrow(ParamTypes, RetType), X }); + + inferBindings(Y->Pattern, Ty, Ctx, NewCtx->Constraints, NewCtx->TVs); + break; } @@ -133,21 +212,43 @@ namespace bolt { } - TVar* Checker::createTypeVar() { - return new TVar(nextTypeVarId++); + TVar* Checker::createTypeVar(InferContext& Ctx) { + auto TV = new TVar(nextTypeVarId++); + Ctx.TVs.emplace(TV); + return TV; } - Type* Checker::instantiate(Scheme& S) { + Type* Checker::instantiate(Scheme& S, InferContext& Ctx, Node* Source) { switch (S.getKind()) { case SchemeKind::Forall: { auto& F = S.as(); + TVSub Sub; - for (auto TV: F.TVs) { - Sub[TV] = createTypeVar(); + if (F.TVs) { + for (auto TV: *F.TVs) { + Sub[TV] = createTypeVar(Ctx); + } } + + if (F.Constraints) { + + for (auto Constraint: *F.Constraints) { + + auto NewConstraint = Constraint->substitute(Sub); + + // This makes error messages prettier by relating the typing failure + // to the call site rather than the definition. + if (NewConstraint->getKind() == ConstraintKind::Equal) { + static_cast(NewConstraint)->Source = Source; + } + + Ctx.addConstraint(NewConstraint); + } + } + return F.Type->substitute(Sub); } @@ -155,6 +256,37 @@ namespace bolt { } + Type* Checker::inferTypeExpression(TypeExpression* X, InferContext& Ctx) { + + switch (X->Type) { + + case NodeType::ReferenceTypeExpression: + { + auto Y = static_cast(X); + auto Ty = Ctx.lookupMono(Y->Name->Name->Text); + if (Ty == nullptr) { + DE.add(Y->Name->Name->Text, Y->Name->Name); + return new TAny(); + } + return Ty; + } + + case NodeType::ArrowTypeExpression: + { + auto Y = static_cast(X); + std::vector ParamTypes; + for (auto ParamType: Y->ParamTypes) { + ParamTypes.push_back(inferTypeExpression(ParamType, Ctx)); + } + auto ReturnType = inferTypeExpression(Y->ReturnType, Ctx); + return new TArrow(ParamTypes, ReturnType); + } + + default: + ZEN_UNREACHABLE + + } + } Type* Checker::inferExpression(Expression* X, InferContext& Ctx) { @@ -166,10 +298,10 @@ namespace bolt { Type* Ty = nullptr; switch (Y->Token->Type) { case NodeType::IntegerLiteral: - Ty = Ctx.Env.lookupMono("Int"); + Ty = Ctx.lookupMono("Int"); break; case NodeType::StringLiteral: - Ty = Ctx.Env.lookupMono("String"); + Ty = Ctx.lookupMono("String"); break; default: ZEN_UNREACHABLE @@ -182,24 +314,37 @@ namespace bolt { { auto Y = static_cast(X); ZEN_ASSERT(Y->Name->ModulePath.empty()); - auto Scm = Ctx.Env.lookup(Y->Name->Name->Text); + auto Scm = Ctx.lookup(Y->Name->Name->Text); if (Scm == nullptr) { DE.add(Y->Name->Name->Text, Y->Name); return new TAny(); } - return instantiate(*Scm); + return instantiate(*Scm, Ctx, X); + } + + case NodeType::CallExpression: + { + auto Y = static_cast(X); + auto OpTy = inferExpression(Y->Function, Ctx); + auto RetType = createTypeVar(Ctx); + std::vector ArgTypes; + for (auto Arg: Y->Args) { + ArgTypes.push_back(inferExpression(Arg, Ctx)); + } + Ctx.addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); + return RetType; } case NodeType::InfixExpression: { auto Y = static_cast(X); - auto Scm = Ctx.Env.lookup(Y->Operator->getText()); + auto Scm = Ctx.lookup(Y->Operator->getText()); if (Scm == nullptr) { DE.add(Y->Operator->getText(), Y->Operator); return new TAny(); } - auto OpTy = instantiate(*Scm); - auto RetTy = createTypeVar(); + auto OpTy = instantiate(*Scm, Ctx, Y->Operator); + auto RetTy = createTypeVar(Ctx); std::vector ArgTys; ArgTys.push_back(inferExpression(Y->LHS, Ctx)); ArgTys.push_back(inferExpression(Y->RHS, Ctx)); @@ -214,24 +359,41 @@ namespace bolt { } + void Checker::inferBindings(Pattern* Pattern, Type* Type, InferContext& Ctx, ConstraintSet& Constraints, TVSet& TVs) { + + switch (Pattern->Type) { + + case NodeType::BindPattern: + Ctx.addBinding(static_cast(Pattern)->Name->Text, Forall(TVs, Constraints, Type)); + break; + + default: + ZEN_UNREACHABLE + + } + } + void Checker::check(SourceFile *SF) { - TypeEnv Global; + InferContext Toplevel; auto StringTy = new TCon(nextConTypeId++, {}, "String"); - Global.add("String", Forall(StringTy)); auto IntTy = new TCon(nextConTypeId++, {}, "Int"); - Global.add("Int", Forall(IntTy)); - Global.add("+", Forall(new TArrow({ IntTy, IntTy }, IntTy))); - ConstraintSet Constraints; - InferContext Toplevel { Constraints, Global }; + auto BoolTy = new TCon(nextConTypeId++, {}, "Bool"); + Toplevel.addBinding("String", Forall(StringTy)); + Toplevel.addBinding("Int", Forall(IntTy)); + Toplevel.addBinding("Bool", Forall(BoolTy)); + Toplevel.addBinding("+", Forall(new TArrow({ IntTy, IntTy }, IntTy))); + Toplevel.addBinding("-", Forall(new TArrow({ IntTy, IntTy }, IntTy))); + Toplevel.addBinding("*", Forall(new TArrow({ IntTy, IntTy }, IntTy))); + Toplevel.addBinding("/", Forall(new TArrow({ IntTy, IntTy }, IntTy))); infer(SF, Toplevel); - solve(new CMany(Constraints)); + solve(new CMany(Toplevel.Constraints)); } void Checker::solve(Constraint* Constraint) { std::stack Queue; Queue.push(Constraint); - TVSub Sub; + TVSub Solution; while (!Queue.empty()) { @@ -256,8 +418,9 @@ namespace bolt { case ConstraintKind::Equal: { auto Y = static_cast(Constraint); - if (!unify(Y->Left, Y->Right, Sub)) { - DE.add(Y->Left, Y->Right, Y->Source); + std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl; + if (!unify(Y->Left, Y->Right, Solution)) { + DE.add(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source); } break; } @@ -288,6 +451,7 @@ namespace bolt { auto Y = static_cast(A); if (B->hasTypeVar(Y)) { // TODO occurs check + return false; } Solution[Y] = B; return true; @@ -297,11 +461,14 @@ namespace bolt { return unify(B, A, Solution); } + if (A->getKind() == TypeKind::Any || B->getKind() == TypeKind::Any) { + return true; + } + if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) { auto Y = static_cast(A); auto Z = static_cast(B); if (Y->ParamTypes.size() != Z->ParamTypes.size()) { - // TODO diagnostic return false; } auto Count = Y->ParamTypes.size(); @@ -313,11 +480,10 @@ namespace bolt { return unify(Y->ReturnType, Z->ReturnType, Solution); } - if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Arrow) { + if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) { auto Y = static_cast(A); auto Z = static_cast(B); if (Y->Id != Z->Id) { - // TODO diagnostic return false; } ZEN_ASSERT(Y->Args.size() == Z->Args.size()); @@ -330,7 +496,6 @@ namespace bolt { return true; } - // TODO diagnostic return false; } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index f3c29d31f..5713cfbd3 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -95,7 +95,7 @@ namespace bolt { } } - static std::string describe(const Type* Ty) { + std::string describe(const Type* Ty) { switch (Ty->getKind()) { case TypeKind::Any: return "any"; @@ -320,10 +320,13 @@ namespace bolt { case DiagnosticKind::BindingNotFound: { auto E = static_cast(D); - Out << ANSI_BOLD ANSI_FG_RED "error: " ANSI_RESET "binding '" << E.Name << "' was not found\n"; - //if (E.Initiator != nullptr) { - // writeExcerpt(E.Initiator->getRange()); - //} + Out << ANSI_BOLD ANSI_FG_RED "error: " ANSI_RESET "binding '" << E.Name << "' was not found\n\n"; + if (E.Initiator != nullptr) { + auto Range = E.Initiator->getRange(); + //std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n"; + writeExcerpt(E.Initiator->getSourceFile()->getTextFile(), Range, Range, Color::Red); + Out << "\n"; + } break; } diff --git a/src/Parser.cc b/src/Parser.cc index debf4218b..093929544 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -106,7 +106,7 @@ namespace bolt { return new QualifiedName(ModulePath, static_cast(Name)); } - TypeExpression* Parser::parseTypeExpression() { + TypeExpression* Parser::parsePrimitiveTypeExpression() { auto T0 = Tokens.peek(); switch (T0->Type) { case NodeType::Identifier: @@ -116,6 +116,24 @@ namespace bolt { } } + TypeExpression* Parser::parseTypeExpression() { + auto RetType = parsePrimitiveTypeExpression(); + std::vector ParamTypes; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->Type != NodeType::RArrow) { + break; + } + Tokens.get(); + ParamTypes.push_back(RetType); + RetType = parsePrimitiveTypeExpression(); + } + if (ParamTypes.size()) { + return new ArrowTypeExpression(ParamTypes, RetType); + } + return RetType; + } + Expression* Parser::parsePrimitiveExpression() { auto T0 = Tokens.peek(); switch (T0->Type) { diff --git a/src/Scanner.cc b/src/Scanner.cc index c0d42b6f6..81c0c333d 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -294,7 +294,9 @@ after_string_contents: Text.push_back(static_cast(C1)); getChar(); } - if (Text == "=") { + if (Text == "->") { + return new RArrow(StartLoc); + } else if (Text == "=") { return new Equals(StartLoc); } else if (Text.back() == '=' && Text[Text.size()-2] != '=') { return new Assignment(Text.substr(0, Text.size()-1), StartLoc);