Introduce new algorithm for better handling of recursive functions

This commit is contained in:
Sam Vervaeck 2023-06-05 15:51:04 +02:00
parent 1853612284
commit 1147c751b9
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
2 changed files with 41 additions and 40 deletions

View file

@ -1724,6 +1724,7 @@ namespace bolt {
public: public:
bool IsCycleActive = false; bool IsCycleActive = false;
bool Visited = false;
InferContext* Ctx; InferContext* Ctx;
class PubKeyword* PubKeyword; class PubKeyword* PubKeyword;

View file

@ -554,6 +554,7 @@ namespace bolt {
// std::cerr << "infer " << Decl->getNameAsString() << std::endl; // std::cerr << "infer " << Decl->getNameAsString() << std::endl;
auto OldCtx = ActiveContext;
setContext(Decl->Ctx); setContext(Decl->Ctx);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
@ -589,6 +590,7 @@ namespace bolt {
makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl); makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl);
setContext(OldCtx);
} }
void Checker::infer(Node* N) { void Checker::infer(Node* N) {
@ -658,18 +660,22 @@ namespace bolt {
{ {
// Function declarations are handled separately in inferLetDeclaration() // Function declarations are handled separately in inferLetDeclaration()
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<LetDeclaration*>(N);
if (!Decl->isVariable()) { if (Decl->isFunction() && !Decl->Visited) {
break; Decl->IsCycleActive = true;
Decl->Visited = true;
inferFunctionDeclaration(Decl);
Decl->IsCycleActive = false;
} else if (Decl->isVariable()) {
auto Ty = Decl->getType();
if (Decl->Body) {
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
makeEqual(Ty, Ty2, Decl);
}
auto Ty3 = inferPattern(Decl->Pattern);
makeEqual(Ty, Ty3, Decl);
} }
auto Ty = Decl->getType();
if (Decl->Body) {
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
makeEqual(Ty, Ty2, Decl);
}
auto Ty3 = inferPattern(Decl->Pattern);
makeEqual(Ty, Ty3, Decl);
break; break;
} }
@ -811,6 +817,7 @@ namespace bolt {
for (auto Arg: AppTE->Args) { for (auto Arg: AppTE->Args) {
Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly)); Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly));
} }
N->setType(Ty);
return Ty; return Ty;
} }
@ -946,18 +953,34 @@ namespace bolt {
{ {
auto Ref = static_cast<ReferenceExpression*>(X); auto Ref = static_cast<ReferenceExpression*>(X);
ZEN_ASSERT(Ref->ModulePath.empty()); ZEN_ASSERT(Ref->ModulePath.empty());
if (Ref->Name->is<IdentifierAlt>()) {
auto Scm = lookup(Ref->Name->getCanonicalText());
if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
Ty = createTypeVar();
break;
}
Ty = instantiate(Scm, X);
break;
}
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
if (Target && llvm::isa<LetDeclaration>(Target)) { if (!Target) {
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
Ty = createTypeVar();
break;
}
if (Target->getKind() == NodeKind::LetDeclaration) {
auto Let = static_cast<LetDeclaration*>(Target); auto Let = static_cast<LetDeclaration*>(Target);
if (Let->IsCycleActive) { if (Let->IsCycleActive) {
return Let->getType(); Ty = Let->getType();
break;
}
if (!Let->Visited) {
infer(Let);
} }
} }
auto Scm = lookup(Ref->Name->getCanonicalText()); auto Scm = lookup(Ref->Name->getCanonicalText());
if (Scm == nullptr) { ZEN_ASSERT(Scm);
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
return createTypeVar();
}
Ty = instantiate(Scm, X); Ty = instantiate(Scm, X);
break; break;
} }
@ -1209,29 +1232,6 @@ namespace bolt {
forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); forwardDeclareFunctionDeclaration(Decl, TVs, Constraints);
} }
} }
for (auto Nodes: SCCs) {
for (auto N: Nodes) {
if (N->getKind() != NodeKind::LetDeclaration) {
continue;
}
auto Decl = static_cast<LetDeclaration*>(N);
Decl->IsCycleActive = true;
}
for (auto N: Nodes) {
if (N->getKind() != NodeKind::LetDeclaration) {
continue;
}
auto Decl = static_cast<LetDeclaration*>(N);
inferFunctionDeclaration(Decl);
}
for (auto N: Nodes) {
if (N->getKind() != NodeKind::LetDeclaration) {
continue;
}
auto Decl = static_cast<LetDeclaration*>(N);
Decl->IsCycleActive = false;
}
}
setContext(SF->Ctx); setContext(SF->Ctx);
infer(SF); infer(SF);