diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 4d59b0069..03ade6e25 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -117,7 +117,7 @@ public: Type* instantiate(TypeScheme* Scm); - void visitPattern(Pattern* P, Type* Ty, TypeEnv& Out); + ConstraintSet visitPattern(Pattern* P, Type* Ty, TypeEnv& Out); ConstraintSet inferSourceFile(TypeEnv& Env, SourceFile* SF); diff --git a/src/Checker.cc b/src/Checker.cc index 190c64de9..225c6e9c1 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -6,8 +6,10 @@ #include "bolt/CST.hpp" #include "bolt/Type.hpp" #include "bolt/Diagnostics.hpp" +#include #include #include +#include #include "bolt/Checker.hpp" namespace bolt { @@ -109,6 +111,42 @@ std::tuple Checker::inferExpr(TypeEnv& Env, Expression* Ex switch (Expr->getKind()) { + case NodeKind::MatchExpression: + { + auto E = static_cast(Expr); + Type* MatchTy; + if (E->hasValue()) { + auto [ValOut, ValTy] = inferExpr(Env, E->getValue(), RetTy); + mergeTo(Out, ValOut); + MatchTy = ValTy; + } else { + MatchTy = createTVar(); + } + Ty = createTVar(); + for (auto Case: E->Cases) { + TypeEnv NewEnv { Env }; + auto PattOut = visitPattern(Case->Pattern, MatchTy, NewEnv); + mergeTo(Out, PattOut); + auto [ExprOut, ExprTy] = inferExpr(NewEnv, Case->Expression, RetTy); + mergeTo(Out, ExprOut); + Out.push_back(new CTypesEqual { ExprTy, Ty, Case->Expression }); + } + if (E->Value) { + auto ParamTy = createTVar(); + Ty = new TFun(ParamTy, Ty); + } + break; + } + + case NodeKind::NestedExpression: + { + auto E = static_cast(Expr); + auto [ExprOut, ExprTy] = inferExpr(Env, E->Inner, RetTy); + mergeTo(Out, ExprOut); + Ty = ExprTy; + break; + } + case NodeKind::FunctionExpression: { auto E = static_cast(Expr); @@ -117,7 +155,8 @@ std::tuple Checker::inferExpr(TypeEnv& Env, Expression* Ex TypeEnv NewEnv { Env }; for (auto P: E->getParameters()) { auto TV = createTVar(); - visitPattern(P, TV, NewEnv); + auto ParamOut = visitPattern(P, TV, NewEnv); + mergeTo(Out, ParamOut); Ty = new TFun(TV, Ty); } auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy); @@ -235,18 +274,45 @@ std::tuple Checker::inferExpr(TypeEnv& Env, Expression* Ex return { Out, Ty }; } -void Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& Out) { +ConstraintSet Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& ToInsert) { + + ConstraintSet Out; + switch (P->getKind()) { + case NodeKind::BindPattern: { auto Q = static_cast(P); // TODO Make a TypedNode out of a Pattern? - Out.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var); + ToInsert.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var); break; } + + case NodeKind::LiteralPattern: + { + auto Lit = static_cast(P); + Type* LitTy; + switch (Lit->Literal->getKind()) { + case NodeKind::StringLiteral: + LitTy = getStringType(); + break; + case NodeKind::IntegerLiteral: + LitTy = getIntType(); + break; + default: + ZEN_UNREACHABLE + } + Out.push_back(new CTypesEqual { Ty, LitTy, Lit }); + break; + } + default: ZEN_UNREACHABLE + } + + return Out; + } std::tuple Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) { @@ -296,7 +362,8 @@ ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaratio for (auto It = Params.end(); It-- != Params.begin(); ) { auto Param = *It; auto ParamTy = createTVar(); - visitPattern(Param->Pattern, ParamTy, NewEnv); + auto ParamOut = visitPattern(Param->Pattern, ParamTy, NewEnv); + mergeTo(Out, ParamOut); Ty = new TFun(ParamTy, Ty); } @@ -418,11 +485,13 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector& Elements, Typ Node* From; void visitReferenceExpression(ReferenceExpression* E) { auto To = E->getScope()->lookup(E->getSymbolPath()); - if (To != nullptr) { + if (To) { if (isa(To)) { To = To->Parent; } - G.add_edge(From, To); + if (isa(To) || isa(To)) { + G.add_edge(From, To); + } } } } V { {}, G, From }; @@ -433,16 +502,16 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector& Elements, Typ for (auto Element: Elements) { if (isa(Element)) { - auto M = static_cast(Element); - G.add_vertex(Element); - if (M->hasBody()) { - populate(M, M->getBody()); + auto Decl = static_cast(Element); + G.add_vertex(Decl); + if (Decl->hasBody()) { + populate(Decl, Decl->getBody()); } } else if (isa(Element)) { - auto M = static_cast(Element); - G.add_vertex(Element); - if (M->hasExpression()) { - populate(M, M->getExpression()); + auto Decl = static_cast(Element); + G.add_vertex(Decl); + if (Decl->hasExpression()) { + populate(Decl, Decl->getExpression()); } } else { Stmts.push_back(Element); @@ -545,19 +614,23 @@ ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, case NodeKind::FunctionExpression: { + ConstraintSet Out; auto E = static_cast(Expr); // FIXME save RetTy on the node and re-use it in this function? if (Expected->getKind() == TypeKind::Fun) { TypeEnv NewEnv { Env }; TFun* Ty = static_cast(Expected); for (auto P: E->getParameters()) { - visitPattern(P, Ty->getLeft(), NewEnv); + auto ParamOut = visitPattern(P, Ty->getLeft(), NewEnv); + mergeTo(Out, ParamOut); if (Ty->getRight()->getKind() != TypeKind::Fun) { goto fallback; } Ty = static_cast(Ty->getRight()); } - return checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight()); + auto ExprOut = checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight()); + mergeTo(Out, ExprOut); + return Out; } goto fallback; } @@ -619,6 +692,8 @@ void Checker::unifyTypeType(Type* A, Type* B, Node* N) { void Checker::run(SourceFile* SF) { TypeEnv Env; + auto A = createTVar(); + auto B = createTVar(); Env.add("Int", getIntType(), SymbolKind::Type); Env.add("Bool", getBoolType(), SymbolKind::Type); Env.add("String", getStringType(), SymbolKind::Type); @@ -626,6 +701,7 @@ void Checker::run(SourceFile* SF) { Env.add("False", getBoolType(), SymbolKind::Var); Env.add("+", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var); Env.add("-", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var); + Env.add("$", new TypeScheme({ A, B }, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var); auto Out = inferSourceFile(Env, SF); solve(Out); }