// TODO Add list of CST variable names to TVar and unify them so that e.g. the typeclass checker may pick one when displaying a diagnostic // TODO (maybe) make unficiation work like union-find in find() // TODO remove Args in TCon and just use it as a constant // TODO make TApp traversable with TupleIndex // TODO make simplify() rewrite the types in-place such that a reference too (Bool, Int).0 becomes Bool // TODO Add a check for datatypes that create infinite structures. // TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node // TODO When a forall variable is missing, do not just insert a blank one into the env. It will result in too few diagnostics being emitted. // Same goes for reference expressions. // If running the compiler as a language server, this matters. // TODO Add a pattern that only performs a type assert // TODO create the constraint in addConstraint, not the other way round #include #include #include #include #include "llvm/Support/Casting.h" #include "bolt/Type.hpp" #include "zen/config.hpp" #include "zen/range.hpp" #include "bolt/CSTVisitor.hpp" #include "bolt/DiagnosticEngine.hpp" #include "bolt/Diagnostics.hpp" #include "bolt/CST.hpp" #include "bolt/Checker.hpp" namespace bolt { std::string describe(const Type* Ty); Constraint* Constraint::substitute(const TVSub &Sub) { switch (Kind) { case ConstraintKind::Class: { auto Class = static_cast(this); std::vector NewTypes; for (auto Ty: Class->Types) { NewTypes.push_back(Ty->substitute(Sub)); } return new CClass(Class->Name, NewTypes); } case ConstraintKind::Equal: { auto Equal = static_cast(this); return new CEqual(Equal->Left->substitute(Sub), Equal->Right->substitute(Sub), Equal->Source); } case ConstraintKind::Many: { auto Many = static_cast(this); auto NewConstraints = new ConstraintSet(); for (auto Element: Many->Elements) { NewConstraints->push_back(Element->substitute(Sub)); } return new CMany(*NewConstraints); } case ConstraintKind::Empty: return this; } } Type* Checker::simplifyType(Type* Ty) { return Ty->rewrite([&](auto Ty) { if (Ty->getKind() == TypeKind::Var) { Ty = static_cast(Ty)->find(); } if (Ty->getKind() == TypeKind::TupleIndex) { auto Index = static_cast(Ty); auto MaybeTuple = simplifyType(Index->Ty); if (MaybeTuple->getKind() == TypeKind::Tuple) { auto Tuple = static_cast(MaybeTuple); if (Index->I >= Tuple->ElementTypes.size()) { DE.add(Tuple, Index->I); } else { Ty = simplifyType(Tuple->ElementTypes[Index->I]); } } } return Ty; }, /*Recursive=*/true); } Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): Config(Config), DE(DE) { BoolType = createConType("Bool"); IntType = createConType("Int"); StringType = createConType("String"); ListType = createConType("List"); } Scheme* Checker::lookup(ByteString Name) { auto Curr = &getContext(); for (;;) { auto Match = Curr->Env.find(Name); if (Match != Curr->Env.end()) { return Match->second; } Curr = Curr->Parent; if (!Curr) { break; } } return nullptr; } Type* Checker::lookupMono(ByteString Name) { auto Scm = lookup(Name); if (Scm == nullptr) { return nullptr; } auto F = static_cast(Scm); ZEN_ASSERT(F->TVs == nullptr || F->TVs->empty()); return F->Type; } void Checker::addBinding(ByteString Name, Scheme* Scm) { getContext().Env.emplace(Name, Scm); } Type* Checker::getReturnType() { auto Ty = getContext().ReturnType; ZEN_ASSERT(Ty != nullptr); return Ty; } static bool hasTypeVar(TVSet& Set, Type* Type) { for (auto TV: Type->getTypeVars()) { if (Set.count(TV)) { return true; } } return false; } void Checker::setContext(InferContext* Ctx) { ActiveContext = Ctx; } void Checker::popContext() { ZEN_ASSERT(ActiveContext); ActiveContext = ActiveContext->Parent; } InferContext& Checker::getContext() { ZEN_ASSERT(ActiveContext); return *ActiveContext; } void Checker::addConstraint(Constraint* C) { switch (C->getKind()) { case ConstraintKind::Class: { getContext().Constraints->push_back(C); break; } case ConstraintKind::Equal: { auto Y = static_cast(C); // This will store all inference contexts in Contexts, from most local // one to most general one. Because this order is not ideal, the code // below will have to handle that. auto Curr = &getContext(); std::vector Contexts; for (;;) { Contexts.push_back(Curr); Curr = Curr->Parent; if (!Curr) { break; } } // If no MaxLevelLeft was found, that means that not a single // corresponding type variable was found in the contexts. We set it to // 0, which corresponds to the global inference context. std::size_t MaxLevelLeft = 0; for (std::size_t I = 0; I < Contexts.size(); I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left)) { MaxLevelLeft = Contexts.size() - I - 1; break; } } // Same as above but now mirrored for Y->Right std::size_t MaxLevelRight = 0; for (std::size_t I = 0; I < Contexts.size(); I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Right)) { MaxLevelRight = Contexts.size() - I - 1; break; } } // The lowest index is determined by the one that has no type variables // in Y->Left AND in Y->Right. This implies max() must be used, so that // the very first enounter of a type variable matters. auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight); // Now find the highest index I such that all the contexts that are more // local do not contain any type variables that are present in the // equality constraint. std::size_t MinLevel = MaxLevel; for (std::size_t I = Contexts.size(); I-- > 0; ) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { // No need to reverse because even though Contexts is reversed, we // are also iterating in reverse. MinLevel = I; break; } } if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) { solveEqual(Y); } else { Contexts[Contexts.size() - MaxLevel - 1]->Constraints->push_back(C); } break; } case ConstraintKind::Many: { auto Y = static_cast(C); for (auto Element: Y->Elements) { addConstraint(Element); } break; } case ConstraintKind::Empty: break; } } void Checker::forwardDeclare(Node* X) { switch (X->getKind()) { case NodeKind::ExpressionStatement: case NodeKind::ReturnStatement: case NodeKind::IfStatement: break; case NodeKind::SourceFile: { auto File = static_cast(X); for (auto Element: File->Elements) { forwardDeclare(Element) ; } break; } case NodeKind::ClassDeclaration: { auto Class = static_cast(X); for (auto TE: Class->TypeVars) { auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); TV->Contexts.emplace(Class->Name->getCanonicalText()); TE->setType(TV); } for (auto Element: Class->Elements) { forwardDeclare(Element); } break; } case NodeKind::InstanceDeclaration: { auto Decl = static_cast(X); // Needed to set the associated Type on the CST node for (auto TE: Decl->TypeExps) { inferTypeExpression(TE); } auto Match = InstanceMap.find(Decl->Name->getCanonicalText()); if (Match == InstanceMap.end()) { InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl }); } else { Match->second.push_back(Decl); } for (auto Element: Decl->Elements) { forwardDeclare(Element); } break; } case NodeKind::FunctionDeclaration: // These declarations will be handled separately in check() break; case NodeKind::VariableDeclaration: // All of this node's semantics will be handled in infer() break; case NodeKind::VariantDeclaration: { auto Decl = static_cast(X); setContext(Decl->Ctx); std::vector Vars; for (auto TE: Decl->TVs) { auto TV = createRigidVar(TE->Name->getCanonicalText()); Decl->Ctx->TVs->emplace(TV); Vars.push_back(TV); } Type* Ty = createConType(Decl->Name->getCanonicalText()); // Must be added early so we can create recursive types Decl->Ctx->Parent->Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty)); for (auto Member: Decl->Members) { switch (Member->getKind()) { case NodeKind::TupleVariantDeclarationMember: { auto TupleMember = static_cast(Member); auto RetTy = Ty; for (auto Var: Vars) { RetTy = new TApp(RetTy, Var); } std::vector ParamTypes; for (auto Element: TupleMember->Elements) { ParamTypes.push_back(inferTypeExpression(Element)); } Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy))); break; } case NodeKind::RecordVariantDeclarationMember: { // TODO break; } default: ZEN_UNREACHABLE } } popContext(); break; } case NodeKind::RecordDeclaration: { auto Decl = static_cast(X); setContext(Decl->Ctx); std::vector Vars; for (auto TE: Decl->Vars) { auto TV = createRigidVar(TE->Name->getCanonicalText()); Vars.push_back(TV); } auto Name = Decl->Name->getCanonicalText(); auto Ty = createConType(Name); // Must be added early so we can create recursive types Decl->Ctx->Parent->Env.emplace(Name, new Forall(Ty)); // Corresponds to the logic of one branch of a VariantDeclarationMember Type* FieldsTy = new TNil(); for (auto Field: Decl->Fields) { FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy); } Type* RetTy = Ty; for (auto TV: Vars) { RetTy = new TApp(RetTy, TV); } Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); popContext(); break; } default: ZEN_UNREACHABLE } } void Checker::initialize(Node* N) { struct Init : public CSTVisitor { Checker& C; std::stack Contexts; InferContext* createDerivedContext() { return C.createInferContext(Contexts.top()); } void visitVariantDeclaration(VariantDeclaration* Decl) { Decl->Ctx = createDerivedContext(); } void visitRecordDeclaration(RecordDeclaration* Decl) { Decl->Ctx = createDerivedContext(); } void visitMatchCase(MatchCase* C) { C->Ctx = createDerivedContext(); Contexts.push(C->Ctx); visitEachChild(C); Contexts.pop(); } void visitSourceFile(SourceFile* SF) { SF->Ctx = C.createInferContext(); Contexts.push(SF->Ctx); visitEachChild(SF); Contexts.pop(); } void visitFunctionDeclaration(FunctionDeclaration* Let) { Let->Ctx = createDerivedContext(); Contexts.push(Let->Ctx); visitEachChild(Let); Contexts.pop(); } // void visitVariableDeclaration(VariableDeclaration* Var) { // Var->Ctx = Contexts.top(); // visitEachChild(Var); // } }; Init I { {}, *this }; I.visit(N); } void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { setContext(Let->Ctx); // If declaring a let-declaration inside a type class declaration, // we need to mark that the let-declaration requires this class. // This marking is set on the rigid type variables of the class, which // are then added to this local type environment. if (Let->isClass()) { auto Class = static_cast(Let->Parent); for (auto TE: Class->TypeVars) { auto TV = llvm::cast(TE->getType()); Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); Let->Ctx->TVs->emplace(TV); } } // Here we infer the primary type of the let declaration. If there's a // type assert, that assert should be authoritative so we use that. // Otherwise, the type is not further specified and we create a new // unification variable. Type* Ty; if (Let->TypeAssert) { Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); } else { Ty = createTypeVar(); } Let->Ty = Ty; // If declaring a let-declaration inside a type instance declaration, // we need to perform some work to make sure the type asserts of the // corresponding let-declaration in the type class declaration are // accounted for. if (Let->isInstance()) { auto Instance = static_cast(Let->Parent); auto Class = llvm::cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); // The type asserts in the type class declaration might make use of // the type parameters of the type class declaration, so it is // important to make them available in the type environment. Moreover, // we will be unifying them with the actual types declared in the // instance declaration, so we keep track of them. std::vector Params; TVSub Sub; for (auto TE: Class->TypeVars) { auto TV = createTypeVar(); Sub.emplace(llvm::cast(TE->getType()), TV); Params.push_back(TV); } auto SigLet = llvm::cast(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var)); // It would be very strange if there was no type assert in the type // class let-declaration but we rather not let the compiler crash if that happens. if (SigLet->TypeAssert) { addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression)->substitute(Sub), Let)); } // Here we do the actual unification of e.g. Eq a with Eq Bool. The // unification variables we created previously will be unified with // e.g. Bool, which causes the type assert to also collapse to e.g. // Bool -> Bool -> Bool. for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { addConstraint(new CEqual(Param, TE->getType(), TE)); } } if (Let->Body) { switch (Let->Body->getKind()) { case NodeKind::LetExprBody: break; case NodeKind::LetBlockBody: { auto Block = static_cast(Let->Body); Let->Ctx->ReturnType = createTypeVar(); for (auto Element: Block->Elements) { forwardDeclare(Element); } break; } default: ZEN_UNREACHABLE } } Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty)); } void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) { setContext(Decl->Ctx); std::vector ParamTypes; Type* RetType; for (auto Param: Decl->Params) { ParamTypes.push_back(inferPattern(Param->Pattern)); } if (Decl->Body) { switch (Decl->Body->getKind()) { case NodeKind::LetExprBody: { auto Expr = static_cast(Decl->Body); RetType = inferExpression(Expr->Expression); break; } case NodeKind::LetBlockBody: { auto Block = static_cast(Decl->Body); RetType = Decl->Ctx->ReturnType; for (auto Element: Block->Elements) { infer(Element); } break; } default: ZEN_UNREACHABLE } } else { RetType = createTypeVar(); } addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl }); } void Checker::infer(Node* N) { switch (N->getKind()) { case NodeKind::SourceFile: { auto File = static_cast(N); for (auto Element: File->Elements) { infer(Element); } break; } case NodeKind::ClassDeclaration: { auto Decl = static_cast(N); for (auto Element: Decl->Elements) { infer(Element); } break; } case NodeKind::InstanceDeclaration: { auto Decl = static_cast(N); for (auto Element: Decl->Elements) { infer(Element); } break; } case NodeKind::VariantDeclaration: case NodeKind::RecordDeclaration: // Nothing to do for a type-level declaration break; case NodeKind::IfStatement: { auto IfStmt = static_cast(N); for (auto Part: IfStmt->Parts) { if (Part->Test != nullptr) { addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test }); } for (auto Element: Part->Elements) { infer(Element); } } break; } case NodeKind::FunctionDeclaration: break; case NodeKind::ReturnStatement: { auto RetStmt = static_cast(N); Type* ReturnType; if (RetStmt->Expression) { addConstraint(new CEqual { inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression }); } else { ReturnType = new TTuple({}); addConstraint(new CEqual { new TTuple({}), getReturnType(), N }); } break; } case NodeKind::VariableDeclaration: { auto Decl = static_cast(N); Type* Ty = nullptr; if (Decl->TypeAssert) { Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false); } if (Decl->Body) { ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); auto E = static_cast(Decl->Body); auto Ty2 = inferExpression(E->Expression); if (Ty) { addConstraint(new CEqual(Ty, Ty2, Decl)); } else { Ty = Ty2; } } auto Ty3 = inferPattern(Decl->Pattern); if (Ty) { addConstraint(new CEqual(Ty, Ty3, Decl)); } else { Ty = Ty3; } Decl->setType(Ty); break; } case NodeKind::ExpressionStatement: { auto ExprStmt = static_cast(N); inferExpression(ExprStmt->Expression); break; } default: ZEN_UNREACHABLE } } TCon* Checker::createConType(ByteString Name) { return new TCon(NextConTypeId++, Name); } TVarRigid* Checker::createRigidVar(ByteString Name) { auto TV = new TVarRigid(NextTypeVarId++, Name); getContext().TVs->emplace(TV); return TV; } TVar* Checker::createTypeVar() { auto TV = new TVar(NextTypeVarId++, VarKind::Unification); getContext().TVs->emplace(TV); return TV; } InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) { auto Ctx = new InferContext; Ctx->Parent = Parent; Ctx->TVs = new TVSet; Ctx->Constraints = new ConstraintSet; return Ctx; } Type* Checker::instantiate(Scheme* Scm, Node* Source) { switch (Scm->getKind()) { case SchemeKind::Forall: { auto F = static_cast(Scm); TVSub Sub; for (auto TV: *F->TVs) { auto Fresh = createTypeVar(); Fresh->Contexts = TV->Contexts; Sub[TV] = Fresh; } for (auto Constraint: *F->Constraints) { // FIXME improve this if (Constraint->getKind() == ConstraintKind::Equal) { auto Eq = static_cast(Constraint); Eq->Left = simplifyType(Eq->Left); Eq->Right = simplifyType(Eq->Right); } auto NewConstraint = Constraint->substitute(Sub); // This makes error messages prettier by relating the typing failure // to the call site rather than the definition. if (NewConstraint->getKind() == ConstraintKind::Equal) { auto Eq = static_cast(Constraint); Eq->Source = Source; } addConstraint(NewConstraint); } // Note the call to simplify? This is because constraints may have already // been solved, with some unification variables being erased. To make // sure we instantiate unification variables that are still in use // we solve before substituting. return simplifyType(F->Type)->substitute(Sub); } } } Constraint* Checker::convertToConstraint(ConstraintExpression* C) { switch (C->getKind()) { case NodeKind::TypeclassConstraintExpression: { auto D = static_cast(C); std::vector Types; for (auto TE: D->TEs) { Types.push_back(inferTypeExpression(TE)); } return new CClass(D->Name->getCanonicalText(), Types); } case NodeKind::EqualityConstraintExpression: { auto D = static_cast(C); return new CEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); } default: ZEN_UNREACHABLE } } Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) { switch (N->getKind()) { case NodeKind::ReferenceTypeExpression: { auto RefTE = static_cast(N); auto Scm = lookup(RefTE->Name->getCanonicalText()); Type* Ty; if (Scm == nullptr) { DE.add(RefTE->Name->getCanonicalText(), RefTE->Name); Ty = createTypeVar(); } else { Ty = instantiate(Scm, RefTE); } N->setType(Ty); return Ty; } case NodeKind::AppTypeExpression: { auto AppTE = static_cast(N); Type* Ty = inferTypeExpression(AppTE->Op, IsPoly); for (auto Arg: AppTE->Args) { Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly)); } return Ty; } case NodeKind::VarTypeExpression: { auto VarTE = static_cast(N); auto Ty = lookupMono(VarTE->Name->getCanonicalText()); if (Ty == nullptr) { if (IsPoly && Config.typeVarsRequireForall()) { DE.add(VarTE->Name->getCanonicalText(), VarTE->Name); } Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); } ZEN_ASSERT(Ty->getKind() == TypeKind::Var); N->setType(Ty); return static_cast(Ty); } case NodeKind::TupleTypeExpression: { auto TupleTE = static_cast(N); std::vector ElementTypes; for (auto [TE, Comma]: TupleTE->Elements) { ElementTypes.push_back(inferTypeExpression(TE, IsPoly)); } auto Ty = new TTuple(ElementTypes); N->setType(Ty); return Ty; } case NodeKind::NestedTypeExpression: { auto NestedTE = static_cast(N); auto Ty = inferTypeExpression(NestedTE->TE, IsPoly); N->setType(Ty); return Ty; } case NodeKind::ArrowTypeExpression: { auto ArrowTE = static_cast(N); std::vector ParamTypes; for (auto ParamType: ArrowTE->ParamTypes) { ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); } auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); auto Ty = new TArrow(ParamTypes, ReturnType); N->setType(Ty); return Ty; } case NodeKind::QualifiedTypeExpression: { auto QTE = static_cast(N); for (auto [C, Comma]: QTE->Constraints) { addConstraint(convertToConstraint(C)); } auto Ty = inferTypeExpression(QTE->TE, IsPoly); N->setType(Ty); return Ty; } default: ZEN_UNREACHABLE } } Type* sortRow(Type* Ty) { std::map Fields; while (Ty->getKind() == TypeKind::Field) { auto Field = static_cast(Ty); Fields.emplace(Field->Name, Field); Ty = Field->RestTy; } for (auto [Name, Field]: Fields) { Ty = new TField(Name, Field->Ty, Ty); } return Ty; } Type* Checker::inferExpression(Expression* X) { Type* Ty; switch (X->getKind()) { case NodeKind::MatchExpression: { auto Match = static_cast(X); Type* ValTy; if (Match->Value) { ValTy = inferExpression(Match->Value); } else { ValTy = createTypeVar(); } Ty = createTypeVar(); for (auto Case: Match->Cases) { auto OldCtx = &getContext(); setContext(Case->Ctx); auto PattTy = inferPattern(Case->Pattern); addConstraint(new CEqual(PattTy, ValTy, Case)); auto ExprTy = inferExpression(Case->Expression); addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); setContext(OldCtx); } if (!Match->Value) { Ty = new TArrow({ ValTy }, Ty); } break; } case NodeKind::RecordExpression: { auto Record = static_cast(X); Ty = new TNil(); for (auto [Field, Comma]: Record->Fields) { Ty = new TField(Field->Name->getCanonicalText(), new TPresent(inferExpression(Field->getExpression())), Ty); } Ty = sortRow(Ty); break; } case NodeKind::ConstantExpression: { auto Const = static_cast(X); Ty = inferLiteral(Const->Token); break; } case NodeKind::ReferenceExpression: { auto Ref = static_cast(X); ZEN_ASSERT(Ref->ModulePath.empty()); auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); if (Target && llvm::isa(Target)) { auto Let = static_cast(Target); if (Let->IsCycleActive) { return Let->Ty; } } auto Scm = lookup(Ref->Name->getCanonicalText()); if (Scm == nullptr) { DE.add(Ref->Name->getCanonicalText(), Ref->Name); return createTypeVar(); } Ty = instantiate(Scm, X); break; } case NodeKind::CallExpression: { auto Call = static_cast(X); auto OpTy = inferExpression(Call->Function); Ty = createTypeVar(); std::vector ArgTypes; for (auto Arg: Call->Args) { ArgTypes.push_back(inferExpression(Arg)); } addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, Ty), X }); break; } case NodeKind::InfixExpression: { auto Infix = static_cast(X); auto Scm = lookup(Infix->Operator->getText()); if (Scm == nullptr) { DE.add(Infix->Operator->getText(), Infix->Operator); return createTypeVar(); } auto OpTy = instantiate(Scm, Infix->Operator); Ty = createTypeVar(); std::vector ArgTys; ArgTys.push_back(inferExpression(Infix->LHS)); ArgTys.push_back(inferExpression(Infix->RHS)); addConstraint(new CEqual { new TArrow(ArgTys, Ty), OpTy, X }); break; } case NodeKind::TupleExpression: { auto Tuple = static_cast(X); std::vector Types; for (auto [E, Comma]: Tuple->Elements) { Types.push_back(inferExpression(E)); } Ty = new TTuple(Types); break; } case NodeKind::MemberExpression: { auto Member = static_cast(X); auto ExprTy = inferExpression(Member->E); switch (Member->Name->getKind()) { case NodeKind::IntegerLiteral: { auto I = static_cast(Member->Name); Ty = new TTupleIndex(ExprTy, I->getInteger()); break; } case NodeKind::Identifier: { auto K = static_cast(Member->Name); Ty = createTypeVar(); auto RestTy = createTypeVar(); addConstraint(new CEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member)); break; } default: ZEN_UNREACHABLE } break; } case NodeKind::NestedExpression: { auto Nested = static_cast(X); Ty = inferExpression(Nested->Inner); break; } default: ZEN_UNREACHABLE } // Ty = find(Ty); X->setType(Ty); return Ty; } Type* Checker::inferPattern( Pattern* Pattern, ConstraintSet* Constraints, TVSet* TVs ) { switch (Pattern->getKind()) { case NodeKind::BindPattern: { auto P = static_cast(Pattern); auto Ty = createTypeVar(); addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty)); return Ty; } case NodeKind::NamedPattern: { auto P = static_cast(Pattern); auto Scm = lookup(P->Name->getCanonicalText()); std::vector ParamTypes; for (auto P2: P->Patterns) { ParamTypes.push_back(inferPattern(P2, Constraints, TVs)); } if (!Scm) { DE.add(P->Name->getCanonicalText(), P->Name); return createTypeVar(); } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); addConstraint(new CEqual(Ty, new TArrow(ParamTypes, RetTy), P)); return RetTy; } case NodeKind::TuplePattern: { auto P = static_cast(Pattern); std::vector ElementTypes; for (auto [Element, Comma]: P->Elements) { ElementTypes.push_back(inferPattern(Element)); } return new TTuple(ElementTypes); } case NodeKind::ListPattern: { auto P = static_cast(Pattern); auto ElementType = createTypeVar(); for (auto [Element, Separator]: P->Elements) { addConstraint(new CEqual(ElementType, inferPattern(Element), P)); } return new TApp(ListType, ElementType); } case NodeKind::NestedPattern: { auto P = static_cast(Pattern); return inferPattern(P->P, Constraints, TVs); } case NodeKind::LiteralPattern: { auto P = static_cast(Pattern); return inferLiteral(P->Literal); } default: ZEN_UNREACHABLE } } Type* Checker::inferLiteral(Literal* L) { Type* Ty; switch (L->getKind()) { case NodeKind::IntegerLiteral: Ty = lookupMono("Int"); break; case NodeKind::StringLiteral: Ty = lookupMono("String"); break; default: ZEN_UNREACHABLE } ZEN_ASSERT(Ty != nullptr); return Ty; } void Checker::populate(SourceFile* SF) { struct Visitor : public CSTVisitor { Graph& RefGraph; std::stack Stack; void visitFunctionDeclaration(FunctionDeclaration* N) { RefGraph.addVertex(N); Stack.push(N); visitEachChild(N); Stack.pop(); } void visitReferenceExpression(ReferenceExpression* N) { auto Y = static_cast(N); auto Def = Y->getScope()->lookup(Y->getSymbolPath()); // Name lookup failures will be reported directly in inferExpression(). if (Def == nullptr || Def->getKind() == NodeKind::SourceFile) { return; } // This case ensures that a deeply nested structure that references a // parameter of a parent node but is not referenced itself is correctly handled. // Note that the edge goes from the parent let to the parameter. This is normal. if (Def->getKind() == NodeKind::Parameter) { RefGraph.addEdge(Stack.top(), Def->Parent); return; } ZEN_ASSERT(Def->getKind() == NodeKind::FunctionDeclaration || Def->getKind() == NodeKind::VariableDeclaration); if (!Stack.empty()) { RefGraph.addEdge(Def, Stack.top()); } } }; Visitor V { {}, RefGraph }; V.visit(SF); } void Checker::checkTypeclassSigs(Node* N) { struct LetVisitor : CSTVisitor { Checker& C; void visitLetDeclaration(FunctionDeclaration* Decl) { // Only inspect those let-declarations that look like a function if (Decl->Params.empty()) { return; } // Will contain the type classes that were specified in the type assertion by the user. // There might be some other signatures as well, but those are an implementation detail. std::vector Expected; // We must add the type class itself to Expected because in order for // propagation to work the rigid type variables expect this class to be // present even inside the current class. By adding it to Expected, we // are effectively cancelling out the default behavior of requiring the // presence of this type classes. if (llvm::isa(Decl->Parent)) { auto Class = llvm::cast(Decl->Parent); std::vector Tys; for (auto TE : Class->TypeVars) { Tys.push_back(llvm::cast(TE->getType())); } Expected.push_back( TypeclassSignature{Class->Name->getCanonicalText(), Tys}); } // Here we scan the type signature for type classes that user expects to be there. if (Decl->TypeAssert != nullptr) { if (llvm::isa(Decl->TypeAssert->TypeExpression)) { auto QTE = static_cast(Decl->TypeAssert->TypeExpression); for (auto [C, Comma]: QTE->Constraints) { if (llvm::isa(C)) { auto TCE = static_cast(C); std::vector Tys; for (auto TE: TCE->TEs) { auto TV = TE->getType(); ZEN_ASSERT(llvm::isa(TV)); Tys.push_back(static_cast(TV)); } Expected.push_back(TypeclassSignature { TCE->Name->getCanonicalText(), Tys }); } } } } // Sort them lexically and remove any duplicates std::sort(Expected.begin(), Expected.end()); Expected.erase(std::unique(Expected.begin(), Expected.end()), Expected.end()); // Will contain the type class signatures that our program inferred that // at the very least should be present to make the body work. std::vector Actual; // This is ugly but it works. Scan all type variables local to this // declaration and add the classes that they require to Actual. for (auto Ty: *Decl->Ctx->TVs) { auto S = Ty->solve(); if (llvm::isa(S)) { auto TV = static_cast(S); for (auto Class: TV->Contexts) { Actual.push_back(TypeclassSignature { Class, { TV } }); } } } // Sort them lexically and remove any duplicates std::sort(Actual.begin(), Actual.end()); Actual.erase(std::unique(Actual.begin(), Actual.end()), Actual.end()); auto ActualIter = Actual.begin(); auto ExpectedIter = Expected.begin(); for (; ActualIter != Actual.end() || ExpectedIter != Expected.end() ;) { // Our program inferred no more type classes that should be present, // yet Expected still did find a few that the user declared in a // signature. No errors should be reported, and we can quit this loop. if (ActualIter == Actual.end()) { // TODO Maybe issue a warning that a type class went unused break; } // There are no more type classes that were expected, so any remaining // type classes in Actual will not have a corresponding signature. // This should be reported as an error. if (ExpectedIter == Expected.end()) { for (; ActualIter != Actual.end(); ActualIter++) { C.DE.add(*ActualIter, Decl); } break; } // If ExpectedIter is already at Show, but ActualIter is still at Eq, // then we clearly missed the Eq in ExpectedIter. This clearly is an // error, since the user missed something in a type signature. if (*ActualIter < *ExpectedIter) { C.DE.add(*ActualIter, Decl); ActualIter++; continue; } // If ActualIter is Show but ExpectedIter is still Eq, then the user // specified too much type classes in a type signature. This is no error, // but it might be worthwhile to issue a warning. if (*ExpectedIter < *ActualIter) { // DE.add(It2->Name, Decl); ExpectedIter++; continue; } // Both type class signatures are equal, cancelling each other out. ActualIter++; ExpectedIter++; } } }; LetVisitor V { {}, *this }; V.visit(N); } Type* Checker::getType(TypedNode *Node) { return Node->getType()->solve(); } void Checker::check(SourceFile *SF) { initialize(SF); setContext(SF->Ctx); addBinding("String", new Forall(StringType)); addBinding("Int", new Forall(IntType)); addBinding("Bool", new Forall(BoolType)); addBinding("List", new Forall(ListType)); addBinding("True", new Forall(BoolType)); addBinding("False", new Forall(BoolType)); auto A = createTypeVar(); addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, new TArrow({ A, A }, BoolType))); addBinding("+", new Forall(new TArrow({ IntType, IntType }, IntType))); addBinding("-", new Forall(new TArrow({ IntType, IntType }, IntType))); addBinding("*", new Forall(new TArrow({ IntType, IntType }, IntType))); addBinding("/", new Forall(new TArrow({ IntType, IntType }, IntType))); populate(SF); forwardDeclare(SF); auto SCCs = RefGraph.strongconnect(); for (auto Nodes: SCCs) { auto TVs = new TVSet; auto Constraints = new ConstraintSet; for (auto N: Nodes) { if (N->getKind() != NodeKind::FunctionDeclaration) { continue; } auto Decl = static_cast(N); forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); } } for (auto Nodes: SCCs) { for (auto N: Nodes) { if (N->getKind() != NodeKind::FunctionDeclaration) { continue; } auto Decl = static_cast(N); Decl->IsCycleActive = true; } for (auto N: Nodes) { if (N->getKind() != NodeKind::FunctionDeclaration) { continue; } auto Decl = static_cast(N); inferFunctionDeclaration(Decl); } for (auto N: Nodes) { if (N->getKind() != NodeKind::FunctionDeclaration) { continue; } auto Decl = static_cast(N); Decl->IsCycleActive = false; } } setContext(SF->Ctx); infer(SF); // Important because otherwise some logic for some optimisations will kick in that are no longer active. ActiveContext = nullptr; solve(new CMany(*SF->Ctx->Constraints)); checkTypeclassSigs(SF); } void Checker::solve(Constraint* Constraint) { Queue.push_back(Constraint); while (!Queue.empty()) { auto Constraint = Queue.front(); Queue.pop_front(); switch (Constraint->getKind()) { case ConstraintKind::Class: { // TODO break; } case ConstraintKind::Empty: break; case ConstraintKind::Many: { auto Many = static_cast(Constraint); for (auto Constraint: Many->Elements) { Queue.push_back(Constraint); } break; } case ConstraintKind::Equal: { solveEqual(static_cast(Constraint)); break; } } } } bool assignableTo(Type* A, Type* B) { if (llvm::isa(A) && llvm::isa(B)) { auto Con1 = llvm::cast(A); auto Con2 = llvm::cast(B); if (Con1->Id != Con2-> Id) { return false; } // TODO must handle a TApp // ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); // for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) { // if (!assignableTo(T1, T2)) { // return false; // } // } return true; } ZEN_UNREACHABLE } class ArrowCursor { std::stack> Stack; TypePath& Path; std::size_t I; public: ArrowCursor(TArrow* Arr, TypePath& Path): Path(Path) { Stack.push({ Arr, true }); Path.push_back(Arr->getStartIndex()); } Type* next() { while (!Stack.empty()) { auto& [Arrow, First] = Stack.top(); auto& Index = Path.back(); if (!First) { Index.advance(Arrow); } else { First = false; } Type* Ty; if (Index == Arrow->getEndIndex()) { Path.pop_back(); Stack.pop(); continue; } Ty = Arrow->resolve(Index); if (llvm::isa(Ty)) { auto NewIndex = Arrow->getStartIndex(); Stack.push({ static_cast(Ty), true }); Path.push_back(NewIndex); } else { return Ty; } } return nullptr; } }; struct Unifier { Checker& C; CEqual* Constraint; // Internal state used by the unifier ByteString CurrentFieldName; TypePath LeftPath; TypePath RightPath; Type* getLeft() const { return Constraint->Left; } Type* getRight() const { return Constraint->Right; } Node* getSource() const { return Constraint->Source; } bool unify(Type* A, Type* B, bool DidSwap); bool unifyField(Type* A, Type* B); bool unify() { return unify(Constraint->Left, Constraint->Right, false); } std::vector findInstanceContext(TCon* Ty, TypeclassId& Class) { auto Match = C.InstanceMap.find(Class); std::vector S; if (Match != C.InstanceMap.end()) { for (auto Instance: Match->second) { if (assignableTo(Ty, Instance->TypeExps[0]->getType())) { std::vector S; // TODO handle TApp // for (auto Arg: Ty->Args) { // TypeclassContext Classes; // // TODO // S.push_back(Classes); // } return S; } } } C.DE.add(Class, Ty, getSource()); // TODO handle TApp // for (auto Arg: Ty->Args) { // S.push_back({}); // } return S; } void propagateClasses(std::unordered_set& Classes, Type* Ty) { if (llvm::isa(Ty)) { auto TV = llvm::cast(Ty); for (auto Class: Classes) { TV->Contexts.emplace(Class); } } else if (llvm::isa(Ty)) { for (auto Class: Classes) { propagateClassTycon(Class, llvm::cast(Ty)); } } else if (!Classes.empty()) { C.DE.add(Ty, std::vector(Classes.begin(), Classes.end()), getSource()); } }; void propagateClassTycon(TypeclassId& Class, TCon* Ty) { auto S = findInstanceContext(Ty, Class); // TODO handle TApp // for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { // propagateClasses(Classes, Arg); // } }; /** * Assign a type to a unification variable. * * If there are class constraints, those are propagated. * * If this type variable is solved during inference, it will be removed from * the inference context. * * Other side effects may occur. */ void join(TVar* TV, Type* Ty) { // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; TV->set(Ty); propagateClasses(TV->Contexts, Ty); // This is a very specific adjustment that is critical to the // well-functioning of the infer/unify algorithm. When addConstraint() is // called, it may decide to solve the constraint immediately during // inference. If this happens, a type variable might get assigned a concrete // type such as Int. We therefore never want the variable to be polymorphic // and be instantiated with a fresh variable, as that would allow Bool to // collide with Int. // // Should it get assigned another unification variable, that's OK too // because then that variable is what matters and it will become the new // (possibly polymorphic) variable. if (C.ActiveContext) { // std::cerr << "erase " << describe(TV) << std::endl; auto TVs = C.ActiveContext->TVs; TVs->erase(TV); } } }; bool Unifier::unify(Type* A, Type* B, bool DidSwap) { A = C.simplifyType(A); B = C.simplifyType(B); auto unifyError = [&]() { C.DE.add( C.simplifyType(Constraint->Left), C.simplifyType(Constraint->Right), LeftPath, RightPath, Constraint->Source ); }; auto pushLeft = [&](TypeIndex I) { if (DidSwap) { RightPath.push_back(I); } else { LeftPath.push_back(I); } }; auto popLeft = [&]() { if (DidSwap) { RightPath.pop_back(); } else { LeftPath.pop_back(); } }; auto pushRight = [&](TypeIndex I) { if (DidSwap) { LeftPath.push_back(I); } else { RightPath.push_back(I); } }; auto popRight = [&]() { if (DidSwap) { LeftPath.pop_back(); } else { RightPath.pop_back(); } }; auto swap = [&]() { std::swap(A, B); DidSwap = !DidSwap; }; auto unifyField = [&](Type* A, Type* B) { if (llvm::isa(A) && llvm::isa(B)) { return true; } if (llvm::isa(B)) { swap(); } if (llvm::isa(A)) { auto Present = static_cast(B); C.DE.add(CurrentFieldName, C.simplifyType(getLeft()), LeftPath, getSource()); return false; } auto Present1 = static_cast(A); auto Present2 = static_cast(B); return unify(Present1->Ty, Present2->Ty, DidSwap); }; if (llvm::isa(A) && llvm::isa(B)) { auto Var1 = static_cast(A); auto Var2 = static_cast(B); if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { if (Var1->Id != Var2->Id) { unifyError(); return false; } return true; } TVar* To; TVar* From; if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { To = Var1; From = Var2; } else { // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification // Either way, Var1, being Unification, is a good candidate for being unified away To = Var2; From = Var1; } if (From->Id != To->Id) { join(From, To); } return true; } if (llvm::isa(B)) { swap(); } if (llvm::isa(A)) { auto TV = static_cast(A); // Rigid type variables can never unify with antything else than what we // have already handled in the previous if-statement, so issue an error. if (TV->getVarKind() == VarKind::Rigid) { unifyError(); return false; } // Occurs check if (B->hasTypeVar(TV)) { // NOTE Just like GHC, we just display an error message indicating that // A cannot match B, e.g. a cannot match [a]. It looks much better // than obsure references to an occurs check unifyError(); return false; } join(TV, B); return true; } if (llvm::isa(A) && llvm::isa(B)) { auto C1 = ArrowCursor(static_cast(A), DidSwap ? RightPath : LeftPath); auto C2 = ArrowCursor(static_cast(B), DidSwap ? LeftPath : RightPath); bool Success = true; for (;;) { auto T1 = C1.next(); auto T2 = C2.next(); if (T1 == nullptr && T2 == nullptr) { break; } if (T1 == nullptr || T2 == nullptr) { unifyError(); Success = false; break; } if (!unify(T1, T2, DidSwap)) { Success = false; } } return Success; /* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */ /* return false; */ /* } */ /* auto Count = Arr1->ParamTypes.size(); */ /* for (std::size_t I = 0; I < Count; I++) { */ /* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */ /* return false; */ /* } */ /* } */ /* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */ } if (llvm::isa(A) && llvm::isa(B)) { auto App1 = static_cast(A); auto App2 = static_cast(B); bool Success = true; if (!unify(App1->Op, App2->Op, DidSwap)) { Success = false; } if (!unify(App1->Arg, App2->Arg, DidSwap)) { Success = false; } return Success; } if (llvm::isa(B)) { swap(); } if (llvm::isa(A)) { auto Arr = static_cast(A); if (Arr->ParamTypes.empty()) { auto Success = unify(Arr->ReturnType, B, DidSwap); return Success; } } if (llvm::isa(A) && llvm::isa(B)) { auto Tuple1 = static_cast(A); auto Tuple2 = static_cast(B); if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { unifyError(); return false; } auto Count = Tuple1->ElementTypes.size(); bool Success = true; for (size_t I = 0; I < Count; I++) { LeftPath.push_back(TypeIndex::forTupleElement(I)); RightPath.push_back(TypeIndex::forTupleElement(I)); if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); } return Success; } if (llvm::isa(A) || llvm::isa(B)) { // Type(s) could not be simplified at the beginning of this function, // so we have to re-visit the constraint when there is more information. C.Queue.push_back(Constraint); return true; } // if (llvm::isa(A) && llvm::isa(B)) { // auto Index1 = static_cast(A); // auto Index2 = static_cast(B); // return unify(Index1->Ty, Index2->Ty, Source); // } if (llvm::isa(A) && llvm::isa(B)) { auto Con1 = static_cast(A); auto Con2 = static_cast(B); if (Con1->Id != Con2->Id) { unifyError(); return false; } return true; } if (llvm::isa(A) && llvm::isa(B)) { return true; } if (llvm::isa(A) && llvm::isa(B)) { auto Field1 = static_cast(A); auto Field2 = static_cast(B); bool Success = true; if (Field1->Name == Field2->Name) { LeftPath.push_back(TypeIndex::forFieldType()); RightPath.push_back(TypeIndex::forFieldType()); CurrentFieldName = Field1->Name; if (!unifyField(Field1->Ty, Field2->Ty)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forFieldRest()); RightPath.push_back(TypeIndex::forFieldRest()); if (!unify(Field1->RestTy, Field2->RestTy, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); return Success; } auto NewRestTy = new TVar(C.NextTypeVarId++, VarKind::Unification); pushLeft(TypeIndex::forFieldRest()); if (!unify(Field1->RestTy, new TField(Field2->Name, Field2->Ty, NewRestTy), DidSwap)) { Success = false; } popLeft(); pushRight(TypeIndex::forFieldRest()); if (!unify(new TField(Field1->Name, Field1->Ty, NewRestTy), Field2->RestTy, DidSwap)) { Success = false; } popRight(); return Success; } if (llvm::isa(A) && llvm::isa(B)) { swap(); } if (llvm::isa(A) && llvm::isa(B)) { auto Field = static_cast(A); bool Success = true; pushLeft(TypeIndex::forFieldType()); CurrentFieldName = Field->Name; if (!unifyField(Field->Ty, new TAbsent)) { Success = false; } popLeft(); pushLeft(TypeIndex::forFieldRest()); if (!unify(Field->RestTy, B, DidSwap)) { Success = false; } popLeft(); return Success; } unifyError(); return false; } void Checker::solveEqual(CEqual* C) { // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; Unifier A { *this, C }; A.unify(); } }