Add support for type-checking function expressions

This commit is contained in:
Sam Vervaeck 2024-07-11 09:09:27 +02:00
parent 556fc28eb7
commit 1d2306513e
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY

View file

@ -109,20 +109,36 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
switch (Expr->getKind()) { switch (Expr->getKind()) {
case NodeKind::FunctionExpression:
{
auto E = static_cast<FunctionExpression*>(Expr);
Type* NewRetTy = createTVar();
Ty = NewRetTy;
TypeEnv NewEnv { Env };
for (auto P: E->getParameters()) {
auto TV = createTVar();
visitPattern(P, TV, NewEnv);
Ty = new TFun(TV, Ty);
}
auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy);
mergeTo(Out, ExprOut);
Out.push_back(new CTypesEqual { ExprTy, NewRetTy, E });
break;
}
case NodeKind::BlockExpression: case NodeKind::BlockExpression:
{ {
auto E = static_cast<BlockExpression*>(Expr); auto E = static_cast<BlockExpression*>(Expr);
auto N = E->Elements.size(); auto N = E->Elements.size();
for (std::size_t I = 0; I+1 < N; ++I) { for (std::size_t I = 0; I+1 < N; ++I) {
auto Element = E->Elements[I]; auto Element = E->Elements[I];
auto ElementOut = inferElement(Env, Element, RetTy);
auto CC = inferElement(Env, Element, RetTy); mergeTo(Out, ElementOut);
mergeTo(Out, CC);
} }
auto Last = E->Elements[N-1]; auto Last = E->Elements[N-1];
auto [CC, ResTy] = inferExpr(Env, cast<Expression>(Last), RetTy); auto [LastOut, LastTy] = inferExpr(Env, cast<Expression>(Last), RetTy);
mergeTo(Out, CC); mergeTo(Out, LastOut);
Ty = ResTy; Ty = LastTy;
break; break;
} }
@ -522,15 +538,33 @@ ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected,
} }
break; break;
default: default:
break; ZEN_UNREACHABLE;
} }
goto fallback;
} }
// TODO case NodeKind::FunctionExpression:
// case NodeKind::FunctionExpression: {
auto E = static_cast<FunctionExpression*>(Expr);
// FIXME save RetTy on the node and re-use it in this function?
if (Expected->getKind() == TypeKind::Fun) {
TypeEnv NewEnv { Env };
TFun* Ty = static_cast<TFun*>(Expected);
for (auto P: E->getParameters()) {
visitPattern(P, Ty->getLeft(), NewEnv);
if (Ty->getRight()->getKind() != TypeKind::Fun) {
goto fallback;
}
Ty = static_cast<TFun*>(Ty->getRight());
}
return checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight());
}
goto fallback;
}
default: default:
{ {
fallback:
auto [Out, Actual] = inferExpr(Env, Expr, RetTy); auto [Out, Actual] = inferExpr(Env, Expr, RetTy);
Out.push_back(new CTypesEqual(Actual, Expected, Expr)); Out.push_back(new CTypesEqual(Actual, Expected, Expr));
return Out; return Out;