diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c40686e8..91f372a47 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,7 @@ if (BOLT_ENABLE_TESTS) add_executable( alltests src/TestText.cc + src/TestChecker.cc ) target_link_libraries( alltests diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 45efafcf9..426cf5c25 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -1,7 +1,9 @@ #ifndef BOLT_CST_HPP #define BOLT_CST_HPP +#include #include +#include #include #include "bolt/Text.hpp" @@ -10,6 +12,10 @@ namespace bolt { + class Token; + class SourceFile; + class Scope; + enum class NodeType { Equals, Colon, @@ -47,6 +53,7 @@ namespace bolt { ArrowTypeExpression, BindPattern, ReferenceExpression, + NestedExpression, ConstantExpression, CallExpression, InfixExpression, @@ -65,8 +72,10 @@ namespace bolt { SourceFile, }; - class Token; - class SourceFile; + struct SymbolPath { + std::vector Modules; + ByteString Name; + }; class Node { @@ -101,10 +110,28 @@ namespace bolt { SourceFile* getSourceFile(); + virtual Scope* getScope(); + virtual ~Node(); }; + class Scope { + + Node* Source; + std::unordered_map Mapping; + + public: + + inline Scope(Node* Source): + Source(Source) {} + + Node* lookup(SymbolPath Path); + + Scope* getParentScope(); + + }; + class Token : public Node { TextLoc StartLoc; @@ -551,6 +578,8 @@ namespace bolt { void setParents() override; + SymbolPath getSymbolPath() const; + ~QualifiedName(); }; @@ -661,6 +690,31 @@ namespace bolt { }; + class NestedExpression : public Expression { + public: + + LParen* LParen; + Expression* Inner; + RParen* RParen; + + inline NestedExpression( + class LParen* LParen, + Expression* Inner, + class RParen* RParen + ): Expression(NodeType::NestedExpression), + LParen(LParen), + Inner(Inner), + RParen(RParen) {} + + void setParents() override; + + Token* getFirstToken() override; + Token* getLastToken() override; + + ~NestedExpression(); + + }; + class ConstantExpression : public Expression { public: @@ -931,9 +985,18 @@ namespace bolt { }; + class Type; + class InferContext; + class LetDeclaration : public Node { + + Scope TheScope; + public: + InferContext* Ctx; + class Type* Ty; + PubKeyword* PubKeyword; LetKeyword* LetKeyword; MutKeyword* MutKeyword; @@ -951,6 +1014,7 @@ namespace bolt { class TypeAssert* TypeAssert, LetBody* Body ): Node(NodeType::LetDeclaration), + TheScope(this), PubKeyword(PubKeyword), LetKeyword(LetKeywod), MutKeyword(MutKeyword), @@ -959,6 +1023,10 @@ namespace bolt { TypeAssert(TypeAssert), Body(Body) {} + inline Scope* getScope() override { + return &TheScope; + } + void setParents() override; Token* getFirstToken() override; @@ -1025,6 +1093,9 @@ namespace bolt { }; class SourceFile : public Node { + + Scope TheScope; + public: TextFile& File; @@ -1032,7 +1103,7 @@ namespace bolt { std::vector Elements; SourceFile(TextFile& File, std::vector Elements): - Node(NodeType::SourceFile), File(File), Elements(Elements) {} + Node(NodeType::SourceFile), TheScope(this), File(File), Elements(Elements) {} inline TextFile& getTextFile() { return File; @@ -1043,6 +1114,10 @@ namespace bolt { Token* getFirstToken() override; Token* getLastToken() override; + inline Scope* getScope() override { + return &TheScope; + } + ~SourceFile(); }; diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 0cb90feaf..6fa6ceb3b 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -4,8 +4,9 @@ #include "zen/config.hpp" #include "bolt/ByteString.hpp" +#include "bolt/CST.hpp" -#include +#include #include #include #include @@ -15,10 +16,6 @@ namespace bolt { class DiagnosticEngine; class Node; - class Expression; - class TypeExpression; - class Pattern; - class SourceFile; class Type; class TVar; @@ -47,6 +44,14 @@ namespace bolt { bool hasTypeVar(const TVar* TV); + void addTypeVars(TVSet& TVs); + + inline TVSet getTypeVars() { + TVSet Out; + addTypeVars(Out); + return Out; + } + Type* substitute(const TVSub& Sub); inline TypeKind getKind() const noexcept { @@ -273,14 +278,6 @@ namespace bolt { inline InferContext(InferContext* Parent = nullptr): Parent(Parent), ReturnType(nullptr) {} - void addConstraint(Constraint* C); - - void addBinding(ByteString Name, Scheme Scm); - - Type* lookupMono(ByteString Name); - - Scheme* lookup(ByteString Name); - }; class Checker { @@ -290,36 +287,66 @@ namespace bolt { size_t nextConTypeId = 0; size_t nextTypeVarId = 0; + std::unordered_map Mapping; + + std::unordered_map CallGraph; + Type* BoolType; Type* IntType; Type* StringType; - std::stack Contexts; + std::vector Contexts; void addConstraint(Constraint* Constraint); - Type* inferExpression(Expression* Expression, InferContext& Ctx); - Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx); + void forwardDeclare(Node* Node); - void inferBindings(Pattern* Pattern, Type* T, InferContext& Ctx, ConstraintSet& Constraints, TVSet& Tvs); + Type* inferExpression(Expression* Expression); + Type* inferTypeExpression(TypeExpression* TE); - void infer(Node* node, InferContext& Ctx); + void inferBindings(Pattern* Pattern, Type* T, ConstraintSet& Constraints, TVSet& Tvs); + + void infer(Node* node); TCon* createPrimConType(); - TVar* createTypeVar(InferContext& Ctx); + TVar* createTypeVar(); - Type* instantiate(Scheme& S, InferContext& Ctx, Node* Source); + void addBinding(ByteString Name, Scheme Scm); + + Type* lookupMono(ByteString Name); + + InferContext* lookupCall(Node* Source, SymbolPath Path); + + Type* getReturnType(); + + Scheme* lookup(ByteString Name); + + Type* instantiate(Scheme& S, Node* Source); bool unify(Type* A, Type* B, TVSub& Solution); - void solve(Constraint* Constraint); + void solve(Constraint* Constraint, TVSub& Solution); public: Checker(DiagnosticEngine& DE); - void check(SourceFile* SF); + TVSub check(SourceFile* SF); + + inline Type* getBoolType() { + return BoolType; + } + + inline Type* getStringType() { + return StringType; + } + + inline Type* getIntType() { + return IntType; + } + + Type* getType(Node* Node, const TVSub& Solution); }; diff --git a/src/CST.cc b/src/CST.cc index 6be1d18e8..e83a410eb 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -5,6 +5,25 @@ namespace bolt { + Node* Scope::lookup(SymbolPath Path) { + auto Curr = this; + do { + auto Match = Curr->Mapping.find(Path.Name); + if (Match != Curr->Mapping.end()) { + return Match->second; + } + Curr = Curr->getParentScope(); + } while (Curr != nullptr); + return nullptr; + } + + Scope* Scope::getParentScope() { + if (Source->Parent == nullptr) { + return nullptr; + } + return Source->Parent->getScope(); + } + SourceFile* Node::getSourceFile() { auto CurrNode = this; for (;;) { @@ -23,6 +42,10 @@ namespace bolt { }; } + Scope* Node::getScope() { + return this->Parent->getScope(); + } + TextLoc Token::getEndLoc() { auto EndLoc = StartLoc; EndLoc.advance(getText()); @@ -61,6 +84,13 @@ namespace bolt { Name->Parent = this; } + void NestedExpression::setParents() { + LParen->Parent = this; + Inner->Parent = this; + Inner->setParents(); + RParen->Parent = this; + } + void ConstantExpression::setParents() { Token->Parent = this; } @@ -330,6 +360,12 @@ namespace bolt { Name->unref(); } + NestedExpression::~NestedExpression() { + LParen->unref(); + Inner->unref(); + RParen->unref(); + } + ConstantExpression::~ConstantExpression() { Token->unref(); } @@ -493,6 +529,14 @@ namespace bolt { return Name->getLastToken(); } + Token* NestedExpression::getFirstToken() { + return LParen; + } + + Token* NestedExpression::getLastToken() { + return RParen; + } + Token* ConstantExpression::getFirstToken() { return Token; } @@ -786,5 +830,13 @@ namespace bolt { return ".."; } + SymbolPath QualifiedName::getSymbolPath() const { + std::vector ModuleNames; + for (auto Ident: ModulePath) { + ModuleNames.push_back(Ident->Text); + } + return SymbolPath { ModuleNames, Name->Text }; + } + } diff --git a/src/Checker.cc b/src/Checker.cc index daa219235..ab80fd5e7 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -11,6 +11,41 @@ namespace bolt { std::string describe(const Type* Ty); + void Type::addTypeVars(TVSet& TVs) { + switch (Kind) { + case TypeKind::Var: + TVs.emplace(static_cast(this)); + break; + case TypeKind::Arrow: + { + auto Y = static_cast(this); + for (auto Ty: Y->ParamTypes) { + Ty->addTypeVars(TVs); + } + Y->ReturnType->addTypeVars(TVs); + break; + } + case TypeKind::Con: + { + auto Y = static_cast(this); + for (auto Ty: Y->Args) { + Ty->addTypeVars(TVs); + } + break; + } + case TypeKind::Tuple: + { + auto Y = static_cast(this); + for (auto Ty: Y->ElementTypes) { + Ty->addTypeVars(TVs); + } + break; + } + case TypeKind::Any: + break; + } + } + bool Type::hasTypeVar(const TVar* TV) { switch (Kind) { case TypeKind::Var: @@ -61,32 +96,50 @@ namespace bolt { case TypeKind::Arrow: { auto Y = static_cast(this); + bool Changed = false; std::vector NewParamTypes; for (auto Ty: Y->ParamTypes) { - NewParamTypes.push_back(Ty->substitute(Sub)); + auto NewParamType = Ty->substitute(Sub); + if (NewParamType != Ty) { + Changed = true; + } + NewParamTypes.push_back(NewParamType); } auto NewRetTy = Y->ReturnType->substitute(Sub) ; - return new TArrow(NewParamTypes, NewRetTy); + if (NewRetTy != Y->ReturnType) { + Changed = true; + } + return Changed ? new TArrow(NewParamTypes, NewRetTy) : this; } case TypeKind::Any: return this; case TypeKind::Con: { auto Y = static_cast(this); + bool Changed = false; std::vector NewArgs; for (auto Arg: Y->Args) { - NewArgs.push_back(Arg->substitute(Sub)); + auto NewArg = Arg->substitute(Sub); + if (NewArg != Arg) { + Changed = true; + } + NewArgs.push_back(NewArg); } - return new TCon(Y->Id, NewArgs, Y->DisplayName); + return Changed ? new TCon(Y->Id, NewArgs, Y->DisplayName) : this; } case TypeKind::Tuple: { auto Y = static_cast(this); + bool Changed = false; std::vector NewElementTypes; for (auto Ty: Y->ElementTypes) { - NewElementTypes.push_back(Ty->substitute(Sub)); + auto NewElementType = Ty->substitute(Sub); + if (NewElementType != Ty) { + Changed = true; + } + NewElementTypes.push_back(NewElementType); } - return new TTuple(NewElementTypes); + return Changed ? new TTuple(NewElementTypes) : this; } } } @@ -102,7 +155,7 @@ namespace bolt { { auto Y = static_cast(this); auto NewConstraints = new ConstraintSet(); - for (auto Element: Y->Constraints) { + for (auto Element: Y->Elements) { NewConstraints->push_back(Element->substitute(Sub)); } return new CMany(*NewConstraints); @@ -112,21 +165,25 @@ namespace bolt { } } - Scheme* InferContext::lookup(ByteString Name) { - InferContext* Curr = this; - for (;;) { + Checker::Checker(DiagnosticEngine& DE): + DE(DE) { + BoolType = new TCon(nextConTypeId++, {}, "Bool"); + IntType = new TCon(nextConTypeId++, {}, "Int"); + StringType = new TCon(nextConTypeId++, {}, "String"); + } + + Scheme* Checker::lookup(ByteString Name) { + for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) { + auto Curr = *Iter; auto Match = Curr->Env.find(Name); if (Match != Curr->Env.end()) { return &Match->second; } - Curr = Curr->Parent; - if (Curr == nullptr) { - return nullptr; - } } + return nullptr; } - Type* InferContext::lookupMono(ByteString Name) { + Type* Checker::lookupMono(ByteString Name) { auto Scm = lookup(Name); if (Scm == nullptr) { return nullptr; @@ -136,22 +193,126 @@ namespace bolt { return F.Type; } - void InferContext::addBinding(ByteString Name, Scheme S) { - Env.emplace(Name, S); + void Checker::addBinding(ByteString Name, Scheme S) { + Contexts.back()->Env.emplace(Name, S); } - void InferContext::addConstraint(Constraint *C) { - Constraints.push_back(C); + Type* Checker::getReturnType() { + auto Ty = Contexts.back()->ReturnType; + ZEN_ASSERT(Ty != nullptr); + return Ty; } - Checker::Checker(DiagnosticEngine& DE): - DE(DE) { - BoolType = new TCon(nextConTypeId++, {}, "Bool"); - IntType = new TCon(nextConTypeId++, {}, "Int"); - StringType = new TCon(nextConTypeId++, {}, "String"); + static bool hasTypeVar(TVSet& Set, Type* Type) { + for (auto TV: Type->getTypeVars()) { + if (Set.count(TV)) { + return true; + } + } + return false; + } + + void Checker::addConstraint(Constraint* Constraint) { + switch (Constraint->getKind()) { + case ConstraintKind::Equal: + { + auto Y = static_cast(Constraint); + for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) { + auto& Ctx = **Iter; + if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) { + Ctx.Constraints.push_back(Constraint); + return; + } + } + Contexts.front()->Constraints.push_back(Constraint); + //auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth); + //ZEN_ASSERT(I < Contexts.size()); + //auto Ctx = Contexts[I]; + //Ctx->Constraints.push_back(Constraint); + break; + } + case ConstraintKind::Many: + { + auto Y = static_cast(Constraint); + for (auto Element: Y->Elements) { + addConstraint(Element); + } + break; + } + case ConstraintKind::Empty: + break; + } + } + + void Checker::forwardDeclare(Node* X) { + + switch (X->Type) { + + case NodeType::ExpressionStatement: + case NodeType::ReturnStatement: + case NodeType::IfStatement: + break; + + case NodeType::SourceFile: + { + auto Y = static_cast(X); + for (auto Element: Y->Elements) { + forwardDeclare(Element) ; + } + break; + } + + case NodeType::LetDeclaration: + { + auto Y = static_cast(X); + + auto NewCtx = new InferContext(); + Y->Ctx = NewCtx; + std::cerr << Y << std::endl; + + Contexts.push_back(NewCtx); + + Type* Ty; + if (Y->TypeAssert) { + Ty = inferTypeExpression(Y->TypeAssert->TypeExpression); + } else { + Ty = createTypeVar(); + } + Y->Ty = Ty; + + if (Y->Body) { + switch (Y->Body->Type) { + case NodeType::LetExprBody: + break; + case NodeType::LetBlockBody: + { + auto Z = static_cast(Y->Body); + for (auto Element: Z->Elements) { + forwardDeclare(Element); + } + break; + } + default: + ZEN_UNREACHABLE + } + } + + Contexts.pop_back(); + + inferBindings(Y->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs); + + + break; + } + + default: + ZEN_UNREACHABLE + } - void Checker::infer(Node* X, InferContext& Ctx) { + } + + void Checker::infer(Node* X) { switch (X->Type) { @@ -159,7 +320,7 @@ namespace bolt { { auto Y = static_cast(X); for (auto Element: Y->Elements) { - infer(Element, Ctx); + infer(Element); } break; } @@ -169,10 +330,10 @@ namespace bolt { auto Y = static_cast(X); for (auto Part: Y->Parts) { if (Part->Test != nullptr) { - Ctx.addConstraint(new CEqual { BoolType, inferExpression(Part->Test, Ctx), Part->Test }); + addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test }); } for (auto Element: Part->Elements) { - infer(Element, Ctx); + infer(Element); } } break; @@ -182,24 +343,18 @@ namespace bolt { { auto Y = static_cast(X); - auto NewCtx = new InferContext { Ctx }; - - Type* Ty; - if (Y->TypeAssert) { - Ty = inferTypeExpression(Y->TypeAssert->TypeExpression, *NewCtx); - } else { - Ty = createTypeVar(*NewCtx); - } + auto NewCtx = Y->Ctx; + Contexts.push_back(NewCtx); std::vector ParamTypes; Type* RetType; for (auto Param: Y->Params) { // TODO incorporate Param->TypeAssert or make it a kind of pattern - TVar* TV = createTypeVar(*NewCtx); + TVar* TV = createTypeVar(); TVSet NoTVs; ConstraintSet NoConstraints; - inferBindings(Param->Pattern, TV, *NewCtx, NoConstraints, NoTVs); + inferBindings(Param->Pattern, TV, NoConstraints, NoTVs); ParamTypes.push_back(TV); } @@ -208,16 +363,16 @@ namespace bolt { case NodeType::LetExprBody: { auto Z = static_cast(Y->Body); - RetType = inferExpression(Z->Expression, *NewCtx); + RetType = inferExpression(Z->Expression); break; } case NodeType::LetBlockBody: { auto Z = static_cast(Y->Body); - RetType = createTypeVar(*NewCtx); + RetType = createTypeVar(); NewCtx->ReturnType = RetType; for (auto Element: Z->Elements) { - infer(Element, *NewCtx); + infer(Element); } break; } @@ -225,12 +380,12 @@ namespace bolt { ZEN_UNREACHABLE } } else { - RetType = createTypeVar(*NewCtx); + RetType = createTypeVar(); } - NewCtx->addConstraint(new CEqual { Ty, new TArrow(ParamTypes, RetType), X }); + addConstraint(new CEqual { Y->Ty, new TArrow(ParamTypes, RetType), X }); - inferBindings(Y->Pattern, Ty, Ctx, NewCtx->Constraints, NewCtx->TVs); + Contexts.pop_back(); break; } @@ -240,19 +395,18 @@ namespace bolt { auto Y = static_cast(X); Type* ReturnType; if (Y->Expression) { - ReturnType = inferExpression(Y->Expression, Ctx); + ReturnType = inferExpression(Y->Expression); } else { ReturnType = new TTuple({}); } - ZEN_ASSERT(Ctx.ReturnType != nullptr); - Ctx.addConstraint(new CEqual { ReturnType, Ctx.ReturnType, X }); + addConstraint(new CEqual { ReturnType, getReturnType(), X }); break; } case NodeType::ExpressionStatement: { auto Y = static_cast(X); - inferExpression(Y->Expression, Ctx); + inferExpression(Y->Expression); break; } @@ -263,13 +417,13 @@ namespace bolt { } - TVar* Checker::createTypeVar(InferContext& Ctx) { + TVar* Checker::createTypeVar() { auto TV = new TVar(nextTypeVarId++); - Ctx.TVs.emplace(TV); + Contexts.back()->TVs.emplace(TV); return TV; } - Type* Checker::instantiate(Scheme& S, InferContext& Ctx, Node* Source) { + Type* Checker::instantiate(Scheme& S, Node* Source) { switch (S.getKind()) { @@ -278,47 +432,46 @@ namespace bolt { auto& F = S.as(); TVSub Sub; - if (F.TVs) { - for (auto TV: *F.TVs) { - Sub[TV] = createTypeVar(Ctx); - } + for (auto TV: *F.TVs) { + Sub[TV] = createTypeVar(); } - if (F.Constraints) { + for (auto Constraint: *F.Constraints) { - for (auto Constraint: *F.Constraints) { + auto NewConstraint = Constraint->substitute(Sub); - auto NewConstraint = Constraint->substitute(Sub); - - // This makes error messages prettier by relating the typing failure - // to the call site rather than the definition. - if (NewConstraint->getKind() == ConstraintKind::Equal) { - static_cast(NewConstraint)->Source = Source; - } - - Ctx.addConstraint(NewConstraint); + // This makes error messages prettier by relating the typing failure + // to the call site rather than the definition. + if (NewConstraint->getKind() == ConstraintKind::Equal) { + static_cast(NewConstraint)->Source = Source; } + + addConstraint(NewConstraint); } - return F.Type->substitute(Sub); + // FIXME substitute should always clone if we set MaxDepth + auto NewType = F.Type->substitute(Sub); + //NewType->MaxDepth = std::max(static_cast(Contexts.size()-1), F.Type->MaxDepth); + return NewType; } } } - Type* Checker::inferTypeExpression(TypeExpression* X, InferContext& Ctx) { + Type* Checker::inferTypeExpression(TypeExpression* X) { switch (X->Type) { case NodeType::ReferenceTypeExpression: { auto Y = static_cast(X); - auto Ty = Ctx.lookupMono(Y->Name->Name->Text); + auto Ty = lookupMono(Y->Name->Name->Text); if (Ty == nullptr) { DE.add(Y->Name->Name->Text, Y->Name->Name); return new TAny(); } + Mapping[X] = Ty; return Ty; } @@ -327,10 +480,12 @@ namespace bolt { auto Y = static_cast(X); std::vector ParamTypes; for (auto ParamType: Y->ParamTypes) { - ParamTypes.push_back(inferTypeExpression(ParamType, Ctx)); + ParamTypes.push_back(inferTypeExpression(ParamType)); } - auto ReturnType = inferTypeExpression(Y->ReturnType, Ctx); - return new TArrow(ParamTypes, ReturnType); + auto ReturnType = inferTypeExpression(Y->ReturnType); + auto Ty = new TArrow(ParamTypes, ReturnType); + Mapping[X] = Ty; + return Ty; } default: @@ -339,7 +494,7 @@ namespace bolt { } } - Type* Checker::inferExpression(Expression* X, InferContext& Ctx) { + Type* Checker::inferExpression(Expression* X) { switch (X->Type) { @@ -349,15 +504,16 @@ namespace bolt { Type* Ty = nullptr; switch (Y->Token->Type) { case NodeType::IntegerLiteral: - Ty = Ctx.lookupMono("Int"); + Ty = lookupMono("Int"); break; case NodeType::StringLiteral: - Ty = Ctx.lookupMono("String"); + Ty = lookupMono("String"); break; default: ZEN_UNREACHABLE } ZEN_ASSERT(Ty != nullptr); + Mapping[X] = Ty; return Ty; } @@ -365,44 +521,58 @@ namespace bolt { { auto Y = static_cast(X); ZEN_ASSERT(Y->Name->ModulePath.empty()); - auto Scm = Ctx.lookup(Y->Name->Name->Text); + auto Ctx = lookupCall(Y, Y->Name->getSymbolPath()); + if (Ctx) { + return Ctx->ReturnType; + } + auto Scm = lookup(Y->Name->Name->Text); if (Scm == nullptr) { DE.add(Y->Name->Name->Text, Y->Name); return new TAny(); } - return instantiate(*Scm, Ctx, X); + auto Ty = instantiate(*Scm, X); + Mapping[X] = Ty; + return Ty; } case NodeType::CallExpression: { auto Y = static_cast(X); - auto OpTy = inferExpression(Y->Function, Ctx); - auto RetType = createTypeVar(Ctx); + auto OpTy = inferExpression(Y->Function); + auto RetType = createTypeVar(); std::vector ArgTypes; for (auto Arg: Y->Args) { - ArgTypes.push_back(inferExpression(Arg, Ctx)); + ArgTypes.push_back(inferExpression(Arg)); } - Ctx.addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); + addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); + Mapping[X] = RetType; return RetType; } case NodeType::InfixExpression: { auto Y = static_cast(X); - auto Scm = Ctx.lookup(Y->Operator->getText()); + auto Scm = lookup(Y->Operator->getText()); if (Scm == nullptr) { DE.add(Y->Operator->getText(), Y->Operator); return new TAny(); } - auto OpTy = instantiate(*Scm, Ctx, Y->Operator); - auto RetTy = createTypeVar(Ctx); + auto OpTy = instantiate(*Scm, Y->Operator); + auto RetTy = createTypeVar(); std::vector ArgTys; - ArgTys.push_back(inferExpression(Y->LHS, Ctx)); - ArgTys.push_back(inferExpression(Y->RHS, Ctx)); - Ctx.addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); + ArgTys.push_back(inferExpression(Y->LHS)); + ArgTys.push_back(inferExpression(Y->RHS)); + addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); + Mapping[X] = RetTy; return RetTy; } + case NodeType::NestedExpression: + { + auto Y = static_cast(X); + return inferExpression(Y->Inner); + } + default: ZEN_UNREACHABLE @@ -410,12 +580,12 @@ namespace bolt { } - void Checker::inferBindings(Pattern* Pattern, Type* Type, InferContext& Ctx, ConstraintSet& Constraints, TVSet& TVs) { + void Checker::inferBindings(Pattern* Pattern, Type* Type, ConstraintSet& Constraints, TVSet& TVs) { switch (Pattern->Type) { case NodeType::BindPattern: - Ctx.addBinding(static_cast(Pattern)->Name->Text, Forall(TVs, Constraints, Type)); + addBinding(static_cast(Pattern)->Name->Text, Forall(TVs, Constraints, Type)); break; default: @@ -424,26 +594,33 @@ namespace bolt { } } - void Checker::check(SourceFile *SF) { - InferContext Toplevel; - Toplevel.addBinding("String", Forall(StringType)); - Toplevel.addBinding("Int", Forall(IntType)); - Toplevel.addBinding("Bool", Forall(BoolType)); - Toplevel.addBinding("True", Forall(BoolType)); - Toplevel.addBinding("False", Forall(BoolType)); - Toplevel.addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType))); - Toplevel.addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType))); - Toplevel.addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType))); - Toplevel.addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType))); - infer(SF, Toplevel); - solve(new CMany(Toplevel.Constraints)); + TVSub Checker::check(SourceFile *SF) { + Contexts.push_back(new InferContext {}); + ConstraintSet NoConstraints; + addBinding("String", Forall(StringType)); + addBinding("Int", Forall(IntType)); + addBinding("Bool", Forall(BoolType)); + addBinding("True", Forall(BoolType)); + addBinding("False", Forall(BoolType)); + auto A = createTypeVar(); + TVSet SingleA { A }; + addBinding("==", Forall(SingleA, NoConstraints, new TArrow({ A, A }, BoolType))); + addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType))); + forwardDeclare(SF); + infer(SF); + TVSub Solution; + solve(new CMany(Contexts.front()->Constraints), Solution); + Contexts.pop_back(); + return Solution; } - void Checker::solve(Constraint* Constraint) { + void Checker::solve(Constraint* Constraint, TVSub& Solution) { std::stack Queue; Queue.push(Constraint); - TVSub Solution; while (!Queue.empty()) { @@ -459,7 +636,7 @@ namespace bolt { case ConstraintKind::Many: { auto Y = static_cast(Constraint); - for (auto Constraint: Y->Constraints) { + for (auto Constraint: Y->Elements) { Queue.push(Constraint); } break; @@ -468,7 +645,7 @@ namespace bolt { case ConstraintKind::Equal: { auto Y = static_cast(Constraint); - std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl; + //std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl; if (!unify(Y->Left, Y->Right, Solution)) { DE.add(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source); } @@ -530,6 +707,17 @@ namespace bolt { return unify(Y->ReturnType, Z->ReturnType, Solution); } + if (A->getKind() == TypeKind::Arrow) { + auto Y = static_cast(A); + if (Y->ParamTypes.empty()) { + return unify(Y->ReturnType, B, Solution); + } + } + + if (B->getKind() == TypeKind::Arrow) { + return unify(B, A, Solution); + } + if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) { auto Y = static_cast(A); auto Z = static_cast(B); @@ -565,5 +753,22 @@ namespace bolt { return false; } + InferContext* Checker::lookupCall(Node* Source, SymbolPath Path) { + auto Def = Source->getScope()->lookup(Path); + auto Match = CallGraph.find(Def); + if (Match == CallGraph.end()) { + return nullptr; + } + return Match->second; + } + + Type* Checker::getType(Node *Node, const TVSub &Solution) { + auto Match = Mapping.find(Node); + if (Match == Mapping.end()) { + return nullptr; + } + return Match->second->substitute(Solution); + } + } diff --git a/src/Parser.cc b/src/Parser.cc index 702549ed7..630407254 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -148,6 +148,13 @@ namespace bolt { auto Name = parseQualifiedName(); return new ReferenceExpression(Name); } + case NodeType::LParen: + { + Tokens.get(); + auto E = parseExpression(); + auto T2 = static_cast(expectToken(NodeType::RParen)); + return new NestedExpression(static_cast(T0), E, T2); + } case NodeType::IntegerLiteral: case NodeType::StringLiteral: Tokens.get(); @@ -162,7 +169,7 @@ namespace bolt { std::vector Args; for (;;) { auto T1 = Tokens.peek(); - if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) { + if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::RParen || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) { break; } Args.push_back(parsePrimitiveExpression());