Add support for type-checking function expressions
This commit is contained in:
parent
556fc28eb7
commit
1d2306513e
1 changed files with 44 additions and 10 deletions
|
@ -109,20 +109,36 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
|
|||
|
||||
switch (Expr->getKind()) {
|
||||
|
||||
case NodeKind::FunctionExpression:
|
||||
{
|
||||
auto E = static_cast<FunctionExpression*>(Expr);
|
||||
Type* NewRetTy = createTVar();
|
||||
Ty = NewRetTy;
|
||||
TypeEnv NewEnv { Env };
|
||||
for (auto P: E->getParameters()) {
|
||||
auto TV = createTVar();
|
||||
visitPattern(P, TV, NewEnv);
|
||||
Ty = new TFun(TV, Ty);
|
||||
}
|
||||
auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy);
|
||||
mergeTo(Out, ExprOut);
|
||||
Out.push_back(new CTypesEqual { ExprTy, NewRetTy, E });
|
||||
break;
|
||||
}
|
||||
|
||||
case NodeKind::BlockExpression:
|
||||
{
|
||||
auto E = static_cast<BlockExpression*>(Expr);
|
||||
auto N = E->Elements.size();
|
||||
for (std::size_t I = 0; I+1 < N; ++I) {
|
||||
auto Element = E->Elements[I];
|
||||
|
||||
auto CC = inferElement(Env, Element, RetTy);
|
||||
mergeTo(Out, CC);
|
||||
auto ElementOut = inferElement(Env, Element, RetTy);
|
||||
mergeTo(Out, ElementOut);
|
||||
}
|
||||
auto Last = E->Elements[N-1];
|
||||
auto [CC, ResTy] = inferExpr(Env, cast<Expression>(Last), RetTy);
|
||||
mergeTo(Out, CC);
|
||||
Ty = ResTy;
|
||||
auto [LastOut, LastTy] = inferExpr(Env, cast<Expression>(Last), RetTy);
|
||||
mergeTo(Out, LastOut);
|
||||
Ty = LastTy;
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -522,15 +538,33 @@ ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected,
|
|||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
ZEN_UNREACHABLE;
|
||||
}
|
||||
goto fallback;
|
||||
}
|
||||
|
||||
// TODO
|
||||
// case NodeKind::FunctionExpression:
|
||||
|
||||
case NodeKind::FunctionExpression:
|
||||
{
|
||||
auto E = static_cast<FunctionExpression*>(Expr);
|
||||
// FIXME save RetTy on the node and re-use it in this function?
|
||||
if (Expected->getKind() == TypeKind::Fun) {
|
||||
TypeEnv NewEnv { Env };
|
||||
TFun* Ty = static_cast<TFun*>(Expected);
|
||||
for (auto P: E->getParameters()) {
|
||||
visitPattern(P, Ty->getLeft(), NewEnv);
|
||||
if (Ty->getRight()->getKind() != TypeKind::Fun) {
|
||||
goto fallback;
|
||||
}
|
||||
Ty = static_cast<TFun*>(Ty->getRight());
|
||||
}
|
||||
return checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight());
|
||||
}
|
||||
goto fallback;
|
||||
}
|
||||
|
||||
default:
|
||||
{
|
||||
fallback:
|
||||
auto [Out, Actual] = inferExpr(Env, Expr, RetTy);
|
||||
Out.push_back(new CTypesEqual(Actual, Expected, Expr));
|
||||
return Out;
|
||||
|
|
Loading…
Reference in a new issue