Improve support for typechecking mutually recursive functions

This commit is contained in:
Sam Vervaeck 2024-07-11 21:05:14 +02:00
parent 1d091c58d0
commit 5cd4cc3e84
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY

View file

@ -394,9 +394,8 @@ ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaratio
Out.push_back(new CTypesEqual(RetTy, BodyTy, Body)); Out.push_back(new CTypesEqual(RetTy, BodyTy, Body));
} }
// Env.add(D->getNameAsString(), Ty, SymbolKind::Var); // inferMany() will have set the type of the node to a fresh type variable.
Out.push_back(new CTypesEqual { D->getType(), Ty, D });
D->setType(Ty);
return Out; return Out;
} }
@ -537,9 +536,20 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
} }
} }
for (auto Nodes: zen::toposort(G)) { for (auto Mutual: zen::toposort(G)) {
ConstraintSet Out; ConstraintSet Out;
for (auto N: Nodes) {
for (auto N: Mutual) {
if (isa<FunctionDeclaration>(N)) {
auto Func = static_cast<FunctionDeclaration*>(N);
Type* Ty = createTVar();
Func->setType(Ty);
Env.add(Func->getNameAsString(), Ty, SymbolKind::Var);
}
}
for (auto N: Mutual) {
if (isa<FunctionDeclaration>(N)) { if (isa<FunctionDeclaration>(N)) {
mergeTo(Out, inferFunctionDeclaration(Env, static_cast<FunctionDeclaration*>(N))); mergeTo(Out, inferFunctionDeclaration(Env, static_cast<FunctionDeclaration*>(N)));
} else if (isa<VariableDeclaration>(N)) { } else if (isa<VariableDeclaration>(N)) {
@ -548,8 +558,10 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
} }
solve(Out); solve(Out);
for (auto N: Nodes) {
for (auto N: Mutual) {
if (isa<FunctionDeclaration>(N)) { if (isa<FunctionDeclaration>(N)) {
auto Func = static_cast<FunctionDeclaration*>(N); auto Func = static_cast<FunctionDeclaration*>(N);
auto Unbound = getUnbound(Env, Func->getType()); auto Unbound = getUnbound(Env, Func->getType());
@ -560,6 +572,7 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
); );
} }
} }
} }
ConstraintSet Out; ConstraintSet Out;