#include "bolt/CSTVisitor.hpp"
#include "zen/graph.hpp"

#include "bolt/ByteString.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
#include "bolt/Diagnostics.hpp"

// NOTE(review): the names of the four system headers were lost when this file
// was extracted (everything between '<' and '>' was stripped). The list below
// covers the std:: names this file uses; confirm against the real file.
#include <functional>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "bolt/Checker.hpp"

// NOTE(review): template arguments in this file were reconstructed from
// context (case labels, accessor usage). Names that could not be derived from
// visible code are marked individually.

namespace bolt {

// Appends every constraint in Other to the end of Out, preserving order.
static inline void mergeTo(ConstraintSet& Out, const ConstraintSet& Other) {
  for (auto C: Other) {
    Out.push_back(C);
  }
}

// Builds the curried function type P0 -> P1 -> ... -> RetTy, so that the
// FIRST parameter ends up as the outermost TFun. This matches the shape
// produced for call expressions and function declarations.
static Type* makeFunctionType(const std::vector<Type*>& ParamTys, Type* RetTy) {
  Type* Ty = RetTy;
  for (auto It = ParamTys.rbegin(); It != ParamTys.rend(); ++It) {
    Ty = new TFun(*It, Ty);
  }
  return Ty;
}

// Looks up (Name, Kind) in this environment and then in each enclosing one by
// following the Parent links. Returns nullptr when no scope has a binding.
TypeScheme* TypeEnv::lookup(ByteString Name, SymbolKind Kind) {
  // The key does not change while walking the scope chain; build it once.
  const auto Key = std::make_tuple(Name, Kind);
  for (auto Curr = this; Curr != nullptr; Curr = Curr->Parent) {
    auto Match = Curr->Mapping.find(Key);
    if (Match != Curr->Mapping.end()) {
      return Match->second;
    }
  }
  return nullptr;
}

// Registers Scm under (Name, Kind) in this scope.
//
// NOTE(review): if Mapping is a unique-key std map, emplace() leaves an
// existing entry untouched, i.e. the first binding for a name wins and a later
// add() for the same name is silently ignored. Confirm this is intended.
void TypeEnv::add(ByteString Name, TypeScheme* Scm, SymbolKind Kind) {
  Mapping.emplace(std::make_tuple(Name, Kind), Scm);
}

// Convenience overload: registers Ty as a scheme that quantifies no variables.
void TypeEnv::add(ByteString Name, Type* Ty, SymbolKind Kind) {
  add(Name, new TypeScheme { {}, Ty }, Kind);
}

// Maps type variables to the types that replace them.
using TVSub = std::unordered_map<TVar*, Type*>;

// Returns Ty with every variable that has an entry in Sub replaced. Solved
// variables are followed through find() first. Sub-trees that do not change
// are shared with the input rather than copied.
Type* substituteType(Type* Ty, const TVSub& Sub) {
  switch (Ty->getKind()) {
    case TypeKind::App:
    {
      auto A = static_cast<TApp*>(Ty);
      auto NewLeft = substituteType(A->getLeft(), Sub);
      auto NewRight = substituteType(A->getRight(), Sub);
      if (A->getLeft() == NewLeft && A->getRight() == NewRight) {
        return Ty;
      }
      return new TApp(NewLeft, NewRight);
    }
    case TypeKind::Con:
      return Ty;
    case TypeKind::Var:
    {
      auto NewTy = Ty->find();
      if (NewTy->getKind() != TypeKind::Var) {
        return substituteType(NewTy, Sub);
      }
      auto Match = Sub.find(static_cast<TVar*>(NewTy));
      return Match == Sub.end() ? NewTy : Match->second;
    }
    case TypeKind::Fun:
    {
      auto F = static_cast<TFun*>(Ty);
      auto NewLeft = substituteType(F->getLeft(), Sub);
      auto NewRight = substituteType(F->getRight(), Sub);
      if (F->getLeft() == NewLeft && F->getRight() == NewRight) {
        return Ty;
      }
      return new TFun(NewLeft, NewRight);
    }
  }
  // Every TypeKind returns above; do not fall off the end of a non-void
  // function if an unexpected kind shows up.
  ZEN_UNREACHABLE
}

// Creates the checker and the built-in Int, Bool, String and unit types.
Checker::Checker(DiagnosticEngine& DE):
  DE(DE) {
    IntType = new TCon("Int");
    BoolType = new TCon("Bool");
    StringType = new TCon("String");
    UnitType = new TCon("()");
  }

// Returns the type of Scm with each of its quantified (Unbound) variables
// replaced by a fresh type variable.
Type* Checker::instantiate(TypeScheme* Scm) {
  TVSub Sub;
  for (auto TV: Scm->Unbound) {
    Sub[TV] = createTVar();
  }
  return substituteType(Scm->getType(), Sub);
}

// Infers the type of Expr in Env.
//
// RetTy is the return type of the enclosing function; `return` expressions
// are constrained against it.
//
// Returns the generated constraints together with the inferred type. The
// inferred type is also stored on Expr via setType().
std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Expr, Type* RetTy) {

  ConstraintSet Out;
  Type* Ty = nullptr;

  for (auto Ann: Expr->Annotations) {
    if (Ann->getKind() == NodeKind::TypeAssertAnnotation) {
      // NOTE(review): the annotated type (AnnTy) is inferred but never related
      // to the type of the expression, so the assertion is not enforced here.
      auto [AnnOut, AnnTy] = inferTypeExpr(Env, static_cast<TypeAssertAnnotation*>(Ann)->getTypeExpression());
      mergeTo(Out, AnnOut);
    }
  }

  switch (Expr->getKind()) {

    case NodeKind::MatchExpression:
    {
      auto E = static_cast<MatchExpression*>(Expr);
      const bool HasValue = E->hasValue();
      Type* MatchTy;
      if (HasValue) {
        auto [ValOut, ValTy] = inferExpr(Env, E->getValue(), RetTy);
        mergeTo(Out, ValOut);
        MatchTy = ValTy;
      } else {
        MatchTy = createTVar();
      }
      // All arms must agree on one result type.
      Type* ResultTy = createTVar();
      for (auto Case: E->Cases) {
        // Bindings introduced by a pattern are only visible in its own arm.
        TypeEnv NewEnv { Env };
        auto PattOut = visitPattern(Case->Pattern, MatchTy, NewEnv);
        mergeTo(Out, PattOut);
        auto [ExprOut, ExprTy] = inferExpr(NewEnv, Case->Expression, RetTy);
        mergeTo(Out, ExprOut);
        Out.push_back(new CTypesEqual { ExprTy, ResultTy, Case->Expression });
      }
      if (HasValue) {
        Ty = ResultTy;
      } else {
        // A match without a scrutinee is a function of the matched value, so
        // its parameter type is the type the patterns were checked against.
        // (Previously an unrelated fresh variable was used here.)
        Ty = new TFun(MatchTy, ResultTy);
      }
      break;
    }

    case NodeKind::NestedExpression:
    {
      auto E = static_cast<NestedExpression*>(Expr);
      auto [ExprOut, ExprTy] = inferExpr(Env, E->Inner, RetTy);
      mergeTo(Out, ExprOut);
      Ty = ExprTy;
      break;
    }

    case NodeKind::FunctionExpression:
    {
      auto E = static_cast<FunctionExpression*>(Expr);
      Type* NewRetTy = createTVar();
      TypeEnv NewEnv { Env };
      std::vector<Type*> ParamTys;
      for (auto P: E->getParameters()) {
        auto TV = createTVar();
        auto ParamOut = visitPattern(P, TV, NewEnv);
        mergeTo(Out, ParamOut);
        ParamTys.push_back(TV);
      }
      // Fold right-to-left so the first parameter is the outermost arrow.
      // Folding in declaration order reversed the parameter types.
      Ty = makeFunctionType(ParamTys, NewRetTy);
      // Inside the body, `return` refers to this function's own return type.
      auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy);
      mergeTo(Out, ExprOut);
      Out.push_back(new CTypesEqual { ExprTy, NewRetTy, E });
      break;
    }

    case NodeKind::BlockExpression:
    {
      auto E = static_cast<BlockExpression*>(Expr);
      auto N = E->Elements.size();
      if (N == 0) {
        // An empty block has no final expression to take its type from;
        // indexing Elements[N-1] would be out of bounds.
        Ty = getUnitType();
        break;
      }
      for (std::size_t I = 0; I+1 < N; ++I) {
        auto ElementOut = inferElement(Env, E->Elements[I], RetTy);
        mergeTo(Out, ElementOut);
      }
      // The block evaluates to its last element.
      auto Last = E->Elements[N-1];
      auto [LastOut, LastTy] = inferExpr(Env, cast<Expression>(Last), RetTy);
      mergeTo(Out, LastOut);
      Ty = LastTy;
      break;
    }

    case NodeKind::ReferenceExpression:
    {
      auto E = static_cast<ReferenceExpression*>(Expr);
      auto Name = E->Name.getCanonicalText();
      auto Match = Env.lookup(Name, SymbolKind::Var);
      if (Match == nullptr) {
        // NOTE(review): the diagnostic's template argument was lost in
        // extraction; BindingNotFoundDiagnostic is a reconstruction.
        DE.add<BindingNotFoundDiagnostic>(Name, E->Name);
        // Keep going with a fresh variable so later checks still run.
        Ty = createTVar();
      } else {
        Ty = instantiate(Match);
      }
      break;
    }

    case NodeKind::LiteralExpression:
    {
      auto E = static_cast<LiteralExpression*>(Expr);
      switch (E->Token->getKind()) {
        case NodeKind::IntegerLiteral:
          Ty = getIntType();
          break;
        case NodeKind::StringLiteral:
          Ty = getStringType();
          break;
        default:
          ZEN_UNREACHABLE
      }
      break;
    }

    case NodeKind::CallExpression:
    {
      auto E = static_cast<CallExpression*>(Expr);
      // Named ResultTy so it does not shadow the RetTy parameter: arguments
      // and the callee still belong to the enclosing function, so a `return`
      // inside them must be checked against the enclosing return type.
      auto ResultTy = createTVar();
      Type* FunTy = ResultTy;
      // Walk the arguments last-to-first so the first argument becomes the
      // outermost arrow. The iterator is decremented inside the body so it
      // never moves before begin().
      for (auto It = E->Args.end(); It != E->Args.begin(); ) {
        --It;
        auto [ArgOut, ArgTy] = inferExpr(Env, *It, RetTy);
        mergeTo(Out, ArgOut);
        FunTy = new TFun(ArgTy, FunTy);
      }
      auto FunOut = checkExpr(Env, E->Function, FunTy, RetTy);
      mergeTo(Out, FunOut);
      Ty = ResultTy;
      break;
    }

    case NodeKind::InfixExpression:
    {
      auto E = static_cast<InfixExpression*>(Expr);
      auto [LeftOut, LeftTy] = inferExpr(Env, E->Left, RetTy);
      mergeTo(Out, LeftOut);
      auto [RightOut, RightTy] = inferExpr(Env, E->Right, RetTy);
      mergeTo(Out, RightOut);
      auto Name = E->Operator.getCanonicalText();
      auto Match = Env.lookup(Name, SymbolKind::Var);
      if (Match == nullptr) {
        // NOTE(review): reconstructed diagnostic type, see ReferenceExpression.
        DE.add<BindingNotFoundDiagnostic>(Name, E->Operator);
        // Fall through to the common exit so the node still gets a type
        // (the early return used before skipped Expr->setType()).
        Ty = createTVar();
        break;
      }
      // An operator is looked up like a two-argument function.
      auto ResultTy = createTVar();
      auto FunTy = new TFun(LeftTy, new TFun(RightTy, ResultTy));
      Out.push_back(new CTypesEqual(FunTy, instantiate(Match), E));
      Ty = ResultTy;
      break;
    }

    case NodeKind::ReturnExpression:
    {
      // NOTE(review): inferSourceFile() passes RetTy == nullptr, so a `return`
      // at the top level produces a constraint with a null side.
      auto E = static_cast<ReturnExpression*>(Expr);
      if (E->hasExpression()) {
        auto [ValOut, ValTy] = inferExpr(Env, E->getExpression(), RetTy);
        mergeTo(Out, ValOut);
        Out.push_back(new CTypesEqual { ValTy, RetTy, E });
      } else {
        Out.push_back(new CTypesEqual { getUnitType(), RetTy, E });
      }
      // Since evaluation stops at the return expression, the expression
      // itself can be matched with any type.
      Ty = createTVar();
      break;
    }

    // TODO LambdaExpression

    default:
      ZEN_UNREACHABLE

  }

  Expr->setType(Ty);
  return { Out, Ty };

}

// Checks pattern P against type Ty and adds the names it binds to ToInsert.
// Returns the constraints the pattern imposes on Ty.
ConstraintSet Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& ToInsert) {

  ConstraintSet Out;

  switch (P->getKind()) {

    case NodeKind::BindPattern:
    {
      auto Q = static_cast<BindPattern*>(P);
      // TODO Make a TypedNode out of a Pattern?
      ToInsert.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var);
      break;
    }

    case NodeKind::LiteralPattern:
    {
      auto Lit = static_cast<LiteralPattern*>(P);
      Type* LitTy = nullptr;
      switch (Lit->Literal->getKind()) {
        case NodeKind::StringLiteral:
          LitTy = getStringType();
          break;
        case NodeKind::IntegerLiteral:
          LitTy = getIntType();
          break;
        default:
          ZEN_UNREACHABLE
      }
      // A literal pattern only matches values of the literal's own type.
      Out.push_back(new CTypesEqual { Ty, LitTy, Lit });
      break;
    }

    default:
      ZEN_UNREACHABLE

  }

  return Out;

}

// Translates the type expression TE into a Type, resolving names in Env.
// Returns the generated constraints and the resulting type, which is also
// stored on TE via setType().
std::tuple<ConstraintSet, Type*> Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) {

  ConstraintSet Out;
  Type* Ty = nullptr;

  switch (TE->getKind()) {

    case NodeKind::ReferenceTypeExpression:
    {
      auto Ref = static_cast<ReferenceTypeExpression*>(TE);
      auto Name = Ref->Name->getCanonicalText();
      auto Match = Env.lookup(Name, SymbolKind::Type);
      if (Match == nullptr) {
        // NOTE(review): reconstructed diagnostic type (template argument was
        // lost in extraction).
        DE.add<BindingNotFoundDiagnostic>(Name, Ref->Name);
        Ty = createTVar();
      } else {
        Ty = instantiate(Match);
      }
      break;
    }

    case NodeKind::ArrowTypeExpression:
    {
      auto Arrow = static_cast<ArrowTypeExpression*>(TE);
      auto [ReturnOut, ReturnTy] = inferTypeExpr(Env, Arrow->ReturnType);
      // These constraints used to be dropped.
      mergeTo(Out, ReturnOut);
      std::vector<Type*> ParamTys;
      for (auto PT: Arrow->ParamTypes) {
        auto [ParamOut, ParamTy] = inferTypeExpr(Env, PT);
        mergeTo(Out, ParamOut);
        ParamTys.push_back(ParamTy);
      }
      // First parameter outermost; folding in declaration order produced
      // the parameters in reverse.
      Ty = makeFunctionType(ParamTys, ReturnTy);
      break;
    }

    default:
      ZEN_UNREACHABLE

  }

  TE->setType(Ty);
  return { Out, Ty };

}

ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaration* D) { auto TA = D->getTypeAssert(); auto Params = D->getParams(); auto Body = D->getBody(); ConstraintSet Out; TypeEnv NewEnv { Env }; auto RetTy = createTVar(); Type* Ty = RetTy; for (auto It = Params.end(); It-- != Params.begin(); ) { auto Param = *It; auto ParamTy = createTVar(); auto ParamOut = visitPattern(Param->Pattern, ParamTy, NewEnv); mergeTo(Out, ParamOut); Ty = new TFun(ParamTy, Ty); } if (TA != nullptr) { auto [TEOut, TETy] = inferTypeExpr(Env, TA->TypeExpression); mergeTo(Out, TEOut); Out.push_back(new CTypesEqual(Ty, TETy, TA->TypeExpression)); } if (Body != nullptr) { // TODO elminate BlockBody and replace with BlockExpr ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody); auto [BodyOut, BodyTy] = inferExpr(NewEnv, cast(Body)->Expression, RetTy); mergeTo(Out, BodyOut); Out.push_back(new CTypesEqual(RetTy, BodyTy, Body)); } // inferMany() will have set the type of the node to a fresh type variable. 
Out.push_back(new CTypesEqual { D->getType(), Ty, D }); return Out; } ConstraintSet Checker::inferVariableDeclaration(TypeEnv& Env, VariableDeclaration* Decl, Type* RetTy) { ConstraintSet Out; Type* Ty = nullptr; if (Decl->TypeAssert != nullptr) { auto [AssertOut, AssertTy] = inferTypeExpr(Env, Decl->TypeAssert->TypeExpression); mergeTo(Out, AssertOut); Ty = AssertTy; } if (Decl->Body != nullptr) { // TODO elminate BlockBody and replace with BlockExpr ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); auto [BodyOut, BodyTy] = inferExpr(Env, cast(Decl->Body)->Expression, RetTy); mergeTo(Out, BodyOut); if (Ty == nullptr) { Ty = BodyTy; } else { Out.push_back(new CTypesEqual(Ty, BodyTy, Decl->Body)); } } // Currently we don't perform generalisation on variable declarations Env.add(Decl->getNameAsString(), Ty, SymbolKind::Var); return Out; } bool hasTypeVar(Type* Ty, TVar* TV) { switch (TV->getKind()) { case TypeKind::App: { auto T = static_cast(Ty); return hasTypeVar(T->getLeft(), TV) || hasTypeVar(T->getRight(), TV); } case TypeKind::Con: return false; case TypeKind::Fun: { auto T = static_cast(Ty); return hasTypeVar(T->getLeft(), TV) || hasTypeVar(T->getRight(), TV); } case TypeKind::Var: { auto T = static_cast(Ty); return T->find() == TV; } } } bool TypeEnv::hasVar(TVar* TV) const { for (auto [_, Scm]: Mapping) { if (Scm->Unbound.count(TV)) { // FIXME ZEN_UNREACHABLE } if (hasTypeVar(Scm->getType(), TV)) { return true; } } return false; } auto getUnbound(const TypeEnv& Env, Type* Ty) { struct Visitor : public TypeVisitor { const TypeEnv& Env; Visitor(const TypeEnv& Env): Env(Env) {} std::vector Out; void visitVar(TVar* TV) { auto Solved = TV->find(); if (isa(Solved)) { auto Var = static_cast(Solved); if (!Env.hasVar(Var)) { Out.push_back(Var); } } else { visit(Solved); } } } V { Env }; V.visit(Ty); return V.Out; } ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector& Elements, Type* RetTy) { using Graph = zen::hash_graph; TypeEnv NewEnv { Env }; 
Graph G; std::function populate = [&](auto From, auto N) { struct Visitor : CSTVisitor { Graph& G; Node* From; void visitReferenceExpression(ReferenceExpression* E) { auto To = E->getScope()->lookup(E->getSymbolPath()); if (To) { if (isa(To)) { To = To->Parent; } if (isa(To) || isa(To)) { G.add_edge(From, To); } } } } V { {}, G, From }; V.visit(N); }; std::vector Stmts; for (auto Element: Elements) { if (isa(Element)) { auto Decl = static_cast(Element); G.add_vertex(Decl); if (Decl->hasBody()) { populate(Decl, Decl->getBody()); } } else if (isa(Element)) { auto Decl = static_cast(Element); G.add_vertex(Decl); if (Decl->hasExpression()) { populate(Decl, Decl->getExpression()); } } else { Stmts.push_back(Element); } } for (auto Mutual: zen::toposort(G)) { ConstraintSet Out; for (auto N: Mutual) { if (isa(N)) { auto Func = static_cast(N); Type* Ty = createTVar(); Func->setType(Ty); Env.add(Func->getNameAsString(), Ty, SymbolKind::Var); } } for (auto N: Mutual) { if (isa(N)) { mergeTo(Out, inferFunctionDeclaration(Env, static_cast(N))); } else if (isa(N)) { mergeTo(Out, inferVariableDeclaration(Env, static_cast(N), RetTy)); } else { ZEN_UNREACHABLE } } solve(Out); for (auto N: Mutual) { if (isa(N)) { auto Func = static_cast(N); auto Unbound = getUnbound(Env, Func->getType()); Env.add( Func->getNameAsString(), new TypeScheme { { Unbound.begin(), Unbound.end() }, Func->getType()->find() }, SymbolKind::Var ); } } } ConstraintSet Out; for (auto Stmt: Stmts) { mergeTo(Out, inferElement(Env, Stmt, RetTy)); } return Out; } ConstraintSet Checker::inferElement(TypeEnv& Env, Node* N, Type* RetTy) { if (isa(N)) { auto [Out, Ty] = inferExpr(Env, cast(N), RetTy); return Out; } switch (N->getKind()) { case NodeKind::PrefixFunctionDeclaration: case NodeKind::InfixFunctionDeclaration: case NodeKind::SuffixFunctionDeclaration: case NodeKind::NamedFunctionDeclaration: return inferFunctionDeclaration(Env, static_cast(N)); case NodeKind::ReturnExpression: { auto M = static_cast(N); if 
(!M->hasExpression()) { return {}; } auto [ValOut, ValTy] = inferExpr(Env, M->getExpression(), RetTy); return { new CTypesEqual(ValTy, RetTy, N) }; } default: ZEN_UNREACHABLE } } ConstraintSet Checker::inferSourceFile(TypeEnv& Env, SourceFile* SF) { return inferMany(Env, SF->Elements, nullptr); } ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, Type* RetTy) { switch (Expr->getKind()) { case NodeKind::LiteralExpression: { auto E = static_cast(Expr); switch (E->Token->getKind()) { case NodeKind::IntegerLiteral: if (*Expected == *getIntType()) { return {}; } break; case NodeKind::StringLiteral: if (*Expected == *getStringType()) { return {}; } break; default: ZEN_UNREACHABLE; } goto fallback; } case NodeKind::FunctionExpression: { ConstraintSet Out; auto E = static_cast(Expr); // FIXME save RetTy on the node and re-use it in this function? if (Expected->getKind() == TypeKind::Fun) { TypeEnv NewEnv { Env }; TFun* Ty = static_cast(Expected); for (auto P: E->getParameters()) { auto ParamOut = visitPattern(P, Ty->getLeft(), NewEnv); mergeTo(Out, ParamOut); if (Ty->getRight()->getKind() != TypeKind::Fun) { goto fallback; } Ty = static_cast(Ty->getRight()); } auto ExprOut = checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight()); mergeTo(Out, ExprOut); return Out; } goto fallback; } default: { fallback: auto [Out, Actual] = inferExpr(Env, Expr, RetTy); Out.push_back(new CTypesEqual(Actual, Expected, Expr)); return Out; } } } void Checker::solve(const std::vector& Constraints) { for (auto C: Constraints) { switch (C->getKind()) { case ConstraintKind::TypesEqual: { auto D = static_cast(C); unifyTypeType(D->getLeft(), D->getRight(), D->getOrigin()); break; } } } } void Checker::unifyTypeType(Type* A, Type* B, Node* N) { A = A->find(); B = B->find(); if (A->getKind() == TypeKind::Var) { auto TV = static_cast(A); // TODO occurs check TV->set(B); return; } if (B->getKind() == TypeKind::Var) { unifyTypeType(B, A, N); return; } if 
(A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) { auto C1 = static_cast(A); auto C2 = static_cast(B); if (C1->getName() == C2->getName()) { return; } } if (A->getKind() == TypeKind::Fun && B->getKind() == TypeKind::Fun) { auto F1 = static_cast(A); auto F2 = static_cast(B); unifyTypeType(F1->getLeft(), F2->getLeft(), N); unifyTypeType(F1->getRight(), F2->getRight(), N); return; } DE.add(A, B, N); } void Checker::run(SourceFile* SF) { TypeEnv Env; auto A = createTVar(); auto B = createTVar(); auto Bool = getBoolType(); auto Int = getIntType(); auto String = getStringType(); Env.add("Int", Int, SymbolKind::Type); Env.add("Bool", Bool, SymbolKind::Type); Env.add("String", String, SymbolKind::Type); Env.add("True", Bool, SymbolKind::Var); Env.add("False", Bool, SymbolKind::Var); Env.add("not", new TFun(Bool, Bool), SymbolKind::Var); Env.add("+", new TFun(Int, new TFun(Int, Int)), SymbolKind::Var); Env.add("-", new TFun(Int, new TFun(Int, Int)), SymbolKind::Var); Env.add("$", new TypeScheme({ A, B }, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var); auto Out = inferSourceFile(Env, SF); solve(Out); } Type* resolveType(Type* Ty) { switch (Ty->getKind()) { case TypeKind::App: { auto A = static_cast(Ty); auto NewLeft = resolveType(A->getLeft()); auto NewRight = resolveType(A->getRight()); if (A->getLeft() == NewLeft && A->getRight() == NewRight) { return Ty; } return new TApp(NewLeft, NewRight); } case TypeKind::Con: return Ty; case TypeKind::Var: { auto NewTy = Ty->find(); if (NewTy->getKind() != TypeKind::Var) { return resolveType(NewTy); } else { return NewTy; } } case TypeKind::Fun: { auto F = static_cast(Ty); auto NewLeft = resolveType(F->getLeft()); auto NewRight = resolveType(F->getRight()); if (F->getLeft() == NewLeft && F->getRight() == NewRight) { return Ty; } return new TFun(NewLeft, NewRight); } } } Type* Checker::getTypeOfNode(Node* N) { auto M = cast(N); return resolveType(M->getType()); } }