diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index f7de9c56c..d4e6cbdfb 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -1724,6 +1724,7 @@ namespace bolt { public: bool IsCycleActive = false; + bool Visited = false; InferContext* Ctx; class PubKeyword* PubKeyword; diff --git a/src/Checker.cc b/src/Checker.cc index a13a28a94..636ac9859 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -554,6 +554,7 @@ namespace bolt { // std::cerr << "infer " << Decl->getNameAsString() << std::endl; + auto OldCtx = ActiveContext; setContext(Decl->Ctx); std::vector ParamTypes; @@ -589,6 +590,7 @@ namespace bolt { makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl); + setContext(OldCtx); } void Checker::infer(Node* N) { @@ -658,18 +660,22 @@ namespace bolt { { // Function declarations are handled separately in inferLetDeclaration() auto Decl = static_cast(N); - if (!Decl->isVariable()) { - break; + if (Decl->isFunction() && !Decl->Visited) { + Decl->IsCycleActive = true; + Decl->Visited = true; + inferFunctionDeclaration(Decl); + Decl->IsCycleActive = false; + } else if (Decl->isVariable()) { + auto Ty = Decl->getType(); + if (Decl->Body) { + ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); + auto E = static_cast(Decl->Body); + auto Ty2 = inferExpression(E->Expression); + makeEqual(Ty, Ty2, Decl); + } + auto Ty3 = inferPattern(Decl->Pattern); + makeEqual(Ty, Ty3, Decl); } - auto Ty = Decl->getType(); - if (Decl->Body) { - ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); - auto E = static_cast(Decl->Body); - auto Ty2 = inferExpression(E->Expression); - makeEqual(Ty, Ty2, Decl); - } - auto Ty3 = inferPattern(Decl->Pattern); - makeEqual(Ty, Ty3, Decl); break; } @@ -811,6 +817,7 @@ namespace bolt { for (auto Arg: AppTE->Args) { Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly)); } + N->setType(Ty); return Ty; } @@ -946,18 +953,34 @@ namespace bolt { { auto Ref = static_cast(X); ZEN_ASSERT(Ref->ModulePath.empty()); + if (Ref->Name->is()) { + auto Scm = lookup(Ref->Name->getCanonicalText()); + if (!Scm) { + DE.add(Ref->Name->getCanonicalText(), Ref->Name); + Ty = createTypeVar(); + break; + } + Ty = instantiate(Scm, X); + break; + } auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); - if (Target && llvm::isa(Target)) { + if (!Target) { + DE.add(Ref->Name->getCanonicalText(), Ref->Name); + Ty = createTypeVar(); + break; + } + if (Target->getKind() == NodeKind::LetDeclaration) { auto Let = static_cast(Target); if (Let->IsCycleActive) { - return Let->getType(); + Ty = Let->getType(); + break; + } + if (!Let->Visited) { + infer(Let); } } auto Scm = lookup(Ref->Name->getCanonicalText()); - if (Scm == nullptr) { - DE.add(Ref->Name->getCanonicalText(), Ref->Name); - return createTypeVar(); - } + ZEN_ASSERT(Scm); Ty = instantiate(Scm, X); break; } @@ -1209,29 +1232,6 @@ namespace bolt { forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); } } - for (auto Nodes: SCCs) { - for (auto N: Nodes) { - if (N->getKind() != NodeKind::LetDeclaration) { - continue; - } - auto Decl = static_cast(N); - Decl->IsCycleActive = true; - } - for (auto N: Nodes) { - if (N->getKind() != NodeKind::LetDeclaration) { - continue; - } - auto Decl = static_cast(N); - inferFunctionDeclaration(Decl); - } - for (auto N: Nodes) { - if (N->getKind() != NodeKind::LetDeclaration) { - continue; - } - auto Decl = static_cast(N); - Decl->IsCycleActive = false; - } - } setContext(SF->Ctx); infer(SF);