diff --git a/src/Checker.cc b/src/Checker.cc index f917e6fc4..190c64de9 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -109,20 +109,36 @@ std::tuple Checker::inferExpr(TypeEnv& Env, Expression* Ex switch (Expr->getKind()) { + case NodeKind::FunctionExpression: + { + auto E = static_cast(Expr); + Type* NewRetTy = createTVar(); + Ty = NewRetTy; + TypeEnv NewEnv { Env }; + for (auto P: E->getParameters()) { + auto TV = createTVar(); + visitPattern(P, TV, NewEnv); + Ty = new TFun(TV, Ty); + } + auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy); + mergeTo(Out, ExprOut); + Out.push_back(new CTypesEqual { ExprTy, NewRetTy, E }); + break; + } + case NodeKind::BlockExpression: { auto E = static_cast(Expr); auto N = E->Elements.size(); for (std::size_t I = 0; I+1 < N; ++I) { auto Element = E->Elements[I]; - - auto CC = inferElement(Env, Element, RetTy); - mergeTo(Out, CC); + auto ElementOut = inferElement(Env, Element, RetTy); + mergeTo(Out, ElementOut); } auto Last = E->Elements[N-1]; - auto [CC, ResTy] = inferExpr(Env, cast(Last), RetTy); - mergeTo(Out, CC); - Ty = ResTy; + auto [LastOut, LastTy] = inferExpr(Env, cast(Last), RetTy); + mergeTo(Out, LastOut); + Ty = LastTy; break; } @@ -522,15 +538,33 @@ ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, } break; default: - break; + ZEN_UNREACHABLE; } + goto fallback; } - // TODO - // case NodeKind::FunctionExpression: - + case NodeKind::FunctionExpression: + { + auto E = static_cast(Expr); + // FIXME save RetTy on the node and re-use it in this function? + if (Expected->getKind() == TypeKind::Fun) { + TypeEnv NewEnv { Env }; + TFun* Ty = static_cast(Expected); + for (auto P: E->getParameters()) { + visitPattern(P, Ty->getLeft(), NewEnv); + if (Ty->getRight()->getKind() != TypeKind::Fun) { + goto fallback; + } + Ty = static_cast(Ty->getRight()); + } + return checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight()); + } + goto fallback; + } + default: { +fallback: auto [Out, Actual] = inferExpr(Env, Expr, RetTy); Out.push_back(new CTypesEqual(Actual, Expected, Expr)); return Out;