Introduce new algorithm for better handling of recursive functions
This commit is contained in:
parent
1853612284
commit
1147c751b9
2 changed files with 41 additions and 40 deletions
|
@ -1724,6 +1724,7 @@ namespace bolt {
|
|||
public:
|
||||
|
||||
bool IsCycleActive = false;
|
||||
bool Visited = false;
|
||||
InferContext* Ctx;
|
||||
|
||||
class PubKeyword* PubKeyword;
|
||||
|
|
|
@ -554,6 +554,7 @@ namespace bolt {
|
|||
|
||||
// std::cerr << "infer " << Decl->getNameAsString() << std::endl;
|
||||
|
||||
auto OldCtx = ActiveContext;
|
||||
setContext(Decl->Ctx);
|
||||
|
||||
std::vector<Type*> ParamTypes;
|
||||
|
@ -589,6 +590,7 @@ namespace bolt {
|
|||
|
||||
makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl);
|
||||
|
||||
setContext(OldCtx);
|
||||
}
|
||||
|
||||
void Checker::infer(Node* N) {
|
||||
|
@ -658,18 +660,22 @@ namespace bolt {
|
|||
{
|
||||
// Function declarations are handled separately in inferLetDeclaration()
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
if (!Decl->isVariable()) {
|
||||
break;
|
||||
if (Decl->isFunction() && !Decl->Visited) {
|
||||
Decl->IsCycleActive = true;
|
||||
Decl->Visited = true;
|
||||
inferFunctionDeclaration(Decl);
|
||||
Decl->IsCycleActive = false;
|
||||
} else if (Decl->isVariable()) {
|
||||
auto Ty = Decl->getType();
|
||||
if (Decl->Body) {
|
||||
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
|
||||
auto E = static_cast<LetExprBody*>(Decl->Body);
|
||||
auto Ty2 = inferExpression(E->Expression);
|
||||
makeEqual(Ty, Ty2, Decl);
|
||||
}
|
||||
auto Ty3 = inferPattern(Decl->Pattern);
|
||||
makeEqual(Ty, Ty3, Decl);
|
||||
}
|
||||
auto Ty = Decl->getType();
|
||||
if (Decl->Body) {
|
||||
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
|
||||
auto E = static_cast<LetExprBody*>(Decl->Body);
|
||||
auto Ty2 = inferExpression(E->Expression);
|
||||
makeEqual(Ty, Ty2, Decl);
|
||||
}
|
||||
auto Ty3 = inferPattern(Decl->Pattern);
|
||||
makeEqual(Ty, Ty3, Decl);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -811,6 +817,7 @@ namespace bolt {
|
|||
for (auto Arg: AppTE->Args) {
|
||||
Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly));
|
||||
}
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
}
|
||||
|
||||
|
@ -946,18 +953,34 @@ namespace bolt {
|
|||
{
|
||||
auto Ref = static_cast<ReferenceExpression*>(X);
|
||||
ZEN_ASSERT(Ref->ModulePath.empty());
|
||||
if (Ref->Name->is<IdentifierAlt>()) {
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
||||
if (!Scm) {
|
||||
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
|
||||
Ty = createTypeVar();
|
||||
break;
|
||||
}
|
||||
Ty = instantiate(Scm, X);
|
||||
break;
|
||||
}
|
||||
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
|
||||
if (Target && llvm::isa<LetDeclaration>(Target)) {
|
||||
if (!Target) {
|
||||
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
|
||||
Ty = createTypeVar();
|
||||
break;
|
||||
}
|
||||
if (Target->getKind() == NodeKind::LetDeclaration) {
|
||||
auto Let = static_cast<LetDeclaration*>(Target);
|
||||
if (Let->IsCycleActive) {
|
||||
return Let->getType();
|
||||
Ty = Let->getType();
|
||||
break;
|
||||
}
|
||||
if (!Let->Visited) {
|
||||
infer(Let);
|
||||
}
|
||||
}
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
||||
if (Scm == nullptr) {
|
||||
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
|
||||
return createTypeVar();
|
||||
}
|
||||
ZEN_ASSERT(Scm);
|
||||
Ty = instantiate(Scm, X);
|
||||
break;
|
||||
}
|
||||
|
@ -1209,29 +1232,6 @@ namespace bolt {
|
|||
forwardDeclareFunctionDeclaration(Decl, TVs, Constraints);
|
||||
}
|
||||
}
|
||||
for (auto Nodes: SCCs) {
|
||||
for (auto N: Nodes) {
|
||||
if (N->getKind() != NodeKind::LetDeclaration) {
|
||||
continue;
|
||||
}
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
Decl->IsCycleActive = true;
|
||||
}
|
||||
for (auto N: Nodes) {
|
||||
if (N->getKind() != NodeKind::LetDeclaration) {
|
||||
continue;
|
||||
}
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
inferFunctionDeclaration(Decl);
|
||||
}
|
||||
for (auto N: Nodes) {
|
||||
if (N->getKind() != NodeKind::LetDeclaration) {
|
||||
continue;
|
||||
}
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
Decl->IsCycleActive = false;
|
||||
}
|
||||
}
|
||||
setContext(SF->Ctx);
|
||||
infer(SF);
|
||||
|
||||
|
|
Loading…
Reference in a new issue