#include #include #include #include "zen/config.hpp" #include "bolt/Type.hpp" #include "bolt/CSTVisitor.hpp" #include "bolt/DiagnosticEngine.hpp" #include "bolt/Diagnostics.hpp" #include "bolt/CST.hpp" #include "bolt/Checker.hpp" namespace bolt { Constraint* Constraint::substitute(const TVSub &Sub) { switch (Kind) { case ConstraintKind::Equal: { auto Equal = static_cast(this); return new CEqual(Equal->Left->substitute(Sub), Equal->Right->substitute(Sub), Equal->Source); } case ConstraintKind::Many: { auto Many = static_cast(this); auto NewConstraints = new ConstraintSet(); for (auto Element: Many->Elements) { NewConstraints->push_back(Element->substitute(Sub)); } return new CMany(*NewConstraints); } case ConstraintKind::Field: { auto Field = static_cast(this); auto NewTupleTy = Field->TupleTy->substitute(Sub); auto NewFieldTy = Field->FieldTy->substitute(Sub); return new CField(NewTupleTy, Field->I, NewFieldTy, Field->Source); } case ConstraintKind::Empty: return this; } ZEN_UNREACHABLE } Type* Checker::solveType(Type* Ty) { return Ty->rewrite([this](auto Ty) { return Ty->find(); }, true); } Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): Config(Config), DE(DE) { BoolType = createConType("Bool"); IntType = createConType("Int"); StringType = createConType("String"); ListType = createConType("List"); UnitType = new Type(TTuple({})); } Scheme* Checker::lookup(ByteString Name, SymKind Kind) { auto Curr = &getContext(); for (;;) { auto Match = Curr->Env.lookup(Name, Kind); if (Match != nullptr) { return Match; } Curr = Curr->Parent; if (!Curr) { break; } } return nullptr; } Type* Checker::lookupMono(ByteString Name, SymKind Kind) { auto Scm = lookup(Name, Kind); if (Scm == nullptr) { return nullptr; } auto F = static_cast(Scm); ZEN_ASSERT(F->TVs == nullptr || F->TVs->empty()); return F->Type; } void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) { getContext().Env.add(Name, Scm, Kind); } Type* Checker::getReturnType() { auto Ty = getContext().ReturnType; ZEN_ASSERT(Ty != nullptr); return Ty; } static bool hasTypeVar(TVSet& Set, Type* Type) { for (auto TV: Type->getTypeVars()) { if (Set.count(TV)) { return true; } } return false; } void Checker::setContext(InferContext* Ctx) { ActiveContext = Ctx; } void Checker::popContext() { ZEN_ASSERT(ActiveContext); ActiveContext = ActiveContext->Parent; } InferContext& Checker::getContext() { ZEN_ASSERT(ActiveContext); return *ActiveContext; } void Checker::makeEqual(Type* A, Type* B, Node* Source) { addConstraint(new CEqual(A, B, Source)); } void Checker::addConstraint(Constraint* C) { switch (C->getKind()) { case ConstraintKind::Field: // FIXME Check if this is all that needs to be done getContext().Constraints->push_back(C); break; case ConstraintKind::Equal: { auto Y = static_cast(C); // This will store all inference contexts in Contexts, from most local // one to most general one. Because this order is not ideal, the code // below will have to handle that. auto Curr = &getContext(); std::vector Contexts; for (;;) { Contexts.push_back(Curr); Curr = Curr->Parent; if (!Curr) { break; } } std::size_t Global = Contexts.size()-1; // If no MaxLevelLeft was found, that means that not a single // corresponding type variable was found in the contexts. We set it to // Contexts.size()-1, which corresponds to the global inference context. std::size_t MaxLevelLeft = Global; for (std::size_t I = 0; I < Global; I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left)) { MaxLevelLeft = I; break; } } // Same as above but now mirrored for Y->Right std::size_t MaxLevelRight = Global; for (std::size_t I = 0; I < Global; I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Right)) { MaxLevelRight = I; break; } } // The lowest index is determined by the one that has no type variables // in Y->Left AND in Y->Right. This implies max() must be used, so that // the very first enounter of a type variable matters. auto UpperLevel = std::max(MaxLevelLeft, MaxLevelRight); // Now find the lowest index LowerLevel such that all the contexts that are more // local do not contain any type variables that are present in the // equality constraint. std::size_t LowerLevel = UpperLevel; for (std::size_t I = Global; I-- > 0; ) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { LowerLevel = I; break; } } if (UpperLevel == LowerLevel || MaxLevelLeft == Global || MaxLevelRight == Global) { unify(Y->Left, Y->Right, Y->Source); } else { Contexts[UpperLevel]->Constraints->push_back(C); } break; } case ConstraintKind::Many: { auto Y = static_cast(C); for (auto Element: Y->Elements) { addConstraint(Element); } break; } case ConstraintKind::Empty: break; } } void Checker::forwardDeclare(Node* X) { switch (X->getKind()) { case NodeKind::ExpressionStatement: case NodeKind::ReturnStatement: case NodeKind::IfStatement: break; case NodeKind::SourceFile: { auto File = static_cast(X); for (auto Element: File->Elements) { forwardDeclare(Element) ; } break; } case NodeKind::ClassDeclaration: { auto Class = static_cast(X); // for (auto TE: Class->TypeVars) { // auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); // // TV->Contexts.emplace(Class->Name->getCanonicalText()); // TE->setType(TV); // } for (auto Element: Class->Elements) { forwardDeclare(Element); } break; } case NodeKind::InstanceDeclaration: { auto Decl = static_cast(X); // Needed to set the associated Type on the CST node for (auto TE: Decl->TypeExps) { inferTypeExpression(TE); } auto Match = InstanceMap.find(getCanonicalText(Decl->Name)); if (Match == InstanceMap.end()) { InstanceMap.emplace(getCanonicalText(Decl->Name), std::vector { Decl }); } else { Match->second.push_back(Decl); } for (auto Element: Decl->Elements) { forwardDeclare(Element); } break; } case NodeKind::LetDeclaration: { // Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression() auto Decl = static_cast(X); if (!Decl->isVariable()) { break; } Type* Ty; if (Decl->TypeAssert) { Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); } else { Ty = createTypeVar(); } Decl->setType(Ty); break; } case NodeKind::VariantDeclaration: { auto Decl = static_cast(X); setContext(Decl->Ctx); std::vector Vars; for (auto TE: Decl->TVs) { auto TV = createRigidVar(getCanonicalText(TE->Name)); Decl->Ctx->TVs->emplace(TV); Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); Vars.push_back(TV); } Type* Ty = createConType(getCanonicalText(Decl->Name)); // Build the type that is actually returned by constructor functions auto RetTy = Ty; for (auto Var: Vars) { RetTy = new Type(TApp(RetTy, Var)); } // Must be added early so we can create recursive types Decl->Ctx->Parent->Env.add(getCanonicalText(Decl->Name), new Forall(Ty), SymKind::Type); for (auto Member: Decl->Members) { switch (Member->getKind()) { case NodeKind::TupleVariantDeclarationMember: { auto TupleMember = static_cast(Member); std::vector ParamTypes; for (auto Element: TupleMember->Elements) { // inferTypeExpression will look up any TVars that were part of the signature of Decl ParamTypes.push_back(inferTypeExpression(Element, false)); } Decl->Ctx->Parent->Env.add( getCanonicalText(TupleMember->Name), new Forall( Decl->Ctx->TVs, Decl->Ctx->Constraints, Type::buildArrow(ParamTypes, RetTy) ), SymKind::Var ); break; } case NodeKind::RecordVariantDeclarationMember: { // TODO break; } default: ZEN_UNREACHABLE } } popContext(); break; } case NodeKind::RecordDeclaration: { auto Decl = static_cast(X); setContext(Decl->Ctx); std::vector Vars; for (auto TE: Decl->Vars) { auto TV = createRigidVar(getCanonicalText(TE->Name)); Decl->Ctx->TVs->emplace(TV); Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); Vars.push_back(TV); } auto Name = getCanonicalText(Decl->Name); auto Ty = createConType(Name); // Must be added early so we can create recursive types Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type); Type* RetTy = Ty; for (auto TV: Vars) { RetTy = new Type(TApp(RetTy, TV)); } // Corresponds to the logic of one branch of a VariantDeclarationMember Type* FieldsTy = new Type(TNil()); for (auto Field: Decl->Fields) { FieldsTy = new Type( TField( getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), FieldsTy ) ); } Decl->Ctx->Parent->Env.add( Name, new Forall( Decl->Ctx->TVs, Decl->Ctx->Constraints, new Type(TArrow(FieldsTy, RetTy)) ), SymKind::Var ); popContext(); break; } default: ZEN_UNREACHABLE } } void Checker::initialize(Node* N) { struct Init : public CSTVisitor { Checker& C; std::stack Contexts; InferContext* createDerivedContext() { return C.createInferContext(Contexts.top()); } void visitVariantDeclaration(VariantDeclaration* Decl) { Decl->Ctx = createDerivedContext(); } void visitRecordDeclaration(RecordDeclaration* Decl) { Decl->Ctx = createDerivedContext(); } void visitMatchCase(MatchCase* C) { C->Ctx = createDerivedContext(); Contexts.push(C->Ctx); visitEachChild(C); Contexts.pop(); } void visitSourceFile(SourceFile* SF) { SF->Ctx = C.createInferContext(); Contexts.push(SF->Ctx); visitEachChild(SF); Contexts.pop(); } void visitLetDeclaration(LetDeclaration* Let) { if (Let->isFunction()) { Let->Ctx = createDerivedContext(); Contexts.push(Let->Ctx); visitEachChild(Let); Contexts.pop(); } } // void visitVariableDeclaration(VariableDeclaration* Var) { // Var->Ctx = Contexts.top(); // visitEachChild(Var); // } }; Init I { {}, *this }; I.visit(N); } void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { if (!Let->isFunction()) { return; } // std::cerr << "declare " << Let->getNameAsString() << std::endl; setContext(Let->Ctx); auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { auto Id = getCanonicalText(Class->Name); auto Ctx = &getContext(); std::vector Out; for (auto TE: Class->TypeVars) { auto Name = getCanonicalText(TE->Name); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); TV->asVar().Context.emplace(Id); Ctx->Env.add(Name, new Forall(TV), SymKind::Type); Out.push_back(TV); } return Out; }; // If declaring a let-declaration inside a type class declaration, // we need to mark that the let-declaration requires this class. // This marking is set on the rigid type variables of the class, which // are then added to this local type environment. if (Let->isClass()) { addClassVars(static_cast(Let->Parent), true); } // Here we infer the primary type of the let declaration. If there's a // type assert, that assert should be authoritative so we use that. // Otherwise, the type is not further specified and we create a new // unification variable. Type* Ty; if (Let->TypeAssert) { Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); } else { Ty = createTypeVar(); } Let->setType(Ty); // If declaring a let-declaration inside a type instance declaration, // we need to perform some work to make sure the type asserts of the // corresponding let-declaration in the type class declaration are // accounted for. if (Let->isInstance()) { auto Instance = static_cast(Let->Parent); auto Class = cast(Instance->getScope()->lookup({ {}, getCanonicalText(Instance->Name) }, SymbolKind::Class)); // TODO check if `Class` is nullptr auto SigLet = cast(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); auto Params = addClassVars(Class, false); // The type asserts in the type class declaration might make use of // the type parameters of the type class declaration, so it is // important to make them available in the type environment. Moreover, // we will be unifying them with the actual types declared in the // instance declaration, so we keep track of them. // std::vector Params; // TVSub Sub; // for (auto TE: Class->TypeVars) { // auto TV = createTypeVar(); // Sub.emplace(cast(TE->getType()), TV); // Params.push_back(TV); // } // Here we do the actual unification of e.g. Eq a with Eq Bool. The // unification variables we created previously will be unified with // e.g. Bool, which causes the type assert to also collapse to e.g. // Bool -> Bool -> Bool. for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { makeEqual(Param, TE->getType(), TE); } // It would be very strange if there was no type assert in the type // class let-declaration but we rather not let the compiler crash if that happens. if (SigLet->TypeAssert) { // Note that we can't do SigLet->TypeAssert->TypeExpression->getType() // because we need to re-generate the type within the local context of // this let-declaration. // TODO make CEqual accept multiple nodes makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let); } } if (Let->Body) { switch (Let->Body->getKind()) { case NodeKind::LetExprBody: break; case NodeKind::LetBlockBody: { auto Block = static_cast(Let->Body); Let->Ctx->ReturnType = createTypeVar(); for (auto Element: Block->Elements) { forwardDeclare(Element); } break; } default: ZEN_UNREACHABLE } } if (!Let->isInstance()) { Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var); } } void Checker::inferFunctionDeclaration(LetDeclaration* Decl) { if (!Decl->isFunction()) { return; } // std::cerr << "infer " << Decl->getNameAsString() << std::endl; auto OldCtx = ActiveContext; setContext(Decl->Ctx); std::vector ParamTypes; Type* RetType; for (auto Param: Decl->Params) { ParamTypes.push_back(inferPattern(Param->Pattern)); } if (Decl->Body) { switch (Decl->Body->getKind()) { case NodeKind::LetExprBody: { auto Expr = static_cast(Decl->Body); RetType = inferExpression(Expr->Expression); break; } case NodeKind::LetBlockBody: { auto Block = static_cast(Decl->Body); RetType = Decl->Ctx->ReturnType; for (auto Element: Block->Elements) { infer(Element); } break; } default: ZEN_UNREACHABLE } } else { RetType = createTypeVar(); } makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl); setContext(OldCtx); } void Checker::infer(Node* N) { switch (N->getKind()) { case NodeKind::SourceFile: { auto File = static_cast(N); for (auto Element: File->Elements) { infer(Element); } break; } case NodeKind::ClassDeclaration: { auto Decl = static_cast(N); for (auto Element: Decl->Elements) { infer(Element); } break; } case NodeKind::InstanceDeclaration: { auto Decl = static_cast(N); for (auto Element: Decl->Elements) { infer(Element); } break; } case NodeKind::VariantDeclaration: case NodeKind::RecordDeclaration: // Nothing to do for a type-level declaration break; case NodeKind::IfStatement: { auto IfStmt = static_cast(N); for (auto Part: IfStmt->Parts) { if (Part->Test != nullptr) { makeEqual(BoolType, inferExpression(Part->Test), Part->Test); } for (auto Element: Part->Elements) { infer(Element); } } break; } case NodeKind::ReturnStatement: { auto RetStmt = static_cast(N); Type* ReturnType; if (RetStmt->Expression) { makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression); } else { ReturnType = UnitType; makeEqual(UnitType, getReturnType(), N); } break; } case NodeKind::LetDeclaration: { // Function declarations are handled separately in inferFunctionDeclaration() auto Decl = static_cast(N); if (Decl->Visited) { break; } if (Decl->isFunction()) { Decl->IsCycleActive = true; Decl->Visited = true; inferFunctionDeclaration(Decl); Decl->IsCycleActive = false; } else if (Decl->isVariable()) { auto Ty = Decl->getType(); if (Decl->Body) { ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); auto E = static_cast(Decl->Body); auto Ty2 = inferExpression(E->Expression); makeEqual(Ty, Ty2, Decl); } auto Ty3 = inferPattern(Decl->Pattern); makeEqual(Ty, Ty3, Decl); } break; } case NodeKind::ExpressionStatement: { auto ExprStmt = static_cast(N); inferExpression(ExprStmt->Expression); break; } default: ZEN_UNREACHABLE } } Type* Checker::createConType(ByteString Name) { return new Type(TCon(NextConTypeId++, Name)); } Type* Checker::createRigidVar(ByteString Name) { auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}})); getContext().TVs->emplace(TV); return TV; } Type* Checker::createTypeVar() { auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {})); getContext().TVs->emplace(TV); return TV; } InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) { auto Ctx = new InferContext; Ctx->Parent = Parent; Ctx->TVs = new TVSet; Ctx->Constraints = new ConstraintSet; return Ctx; } Type* Checker::instantiate(Scheme* Scm, Node* Source) { switch (Scm->getKind()) { case SchemeKind::Forall: { auto F = static_cast(Scm); TVSub Sub; for (auto TV: *F->TVs) { auto Fresh = createTypeVar(); // std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl; Fresh->asVar().Context = TV->asVar().Context; Sub[TV] = Fresh; } for (auto Constraint: *F->Constraints) { // FIXME improve this if (Constraint->getKind() == ConstraintKind::Equal) { auto Eq = static_cast(Constraint); Eq->Left = solveType(Eq->Left); Eq->Right = solveType(Eq->Right); } auto NewConstraint = Constraint->substitute(Sub); // This makes error messages prettier by relating the typing failure // to the call site rather than the definition. if (NewConstraint->getKind() == ConstraintKind::Equal) { auto Eq = static_cast(Constraint); Eq->Source = Source; } addConstraint(NewConstraint); } // This call to solve happens because constraints may have already // been solved, with some unification variables being erased. To make // sure we instantiate unification variables that are still in use // we solve before substituting. return solveType(F->Type)->substitute(Sub); } } ZEN_UNREACHABLE } void Checker::inferConstraintExpression(ConstraintExpression* C) { switch (C->getKind()) { case NodeKind::TypeclassConstraintExpression: { auto D = static_cast(C); std::vector Types; for (auto TE: D->TEs) { auto Ty = inferTypeExpression(TE); Ty->asVar().Provided->emplace(getCanonicalText(D->Name)); Types.push_back(Ty); } break; } case NodeKind::EqualityConstraintExpression: { auto D = static_cast(C); makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); break; } default: ZEN_UNREACHABLE } } Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { switch (N->getKind()) { case NodeKind::ReferenceTypeExpression: { auto RefTE = static_cast(N); auto Scm = lookup(getCanonicalText(RefTE->Name), SymKind::Type); Type* Ty; if (Scm == nullptr) { DE.add(getCanonicalText(RefTE->Name), RefTE->Name); Ty = createTypeVar(); } else { Ty = instantiate(Scm, RefTE); } N->setType(Ty); return Ty; } case NodeKind::AppTypeExpression: { auto AppTE = static_cast(N); Type* Ty = inferTypeExpression(AppTE->Op, AutoVars); for (auto Arg: AppTE->Args) { Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars))); } N->setType(Ty); return Ty; } case NodeKind::VarTypeExpression: { auto VarTE = static_cast(N); auto Ty = lookupMono(getCanonicalText(VarTE->Name), SymKind::Type); if (Ty == nullptr) { if (!AutoVars || Config.typeVarsRequireForall()) { DE.add(getCanonicalText(VarTE->Name), VarTE->Name); } Ty = createRigidVar(getCanonicalText(VarTE->Name)); addBinding(getCanonicalText(VarTE->Name), new Forall(Ty), SymKind::Type); } ZEN_ASSERT(Ty->isVar()); N->setType(Ty); return Ty; } case NodeKind::RecordTypeExpression: { auto RecTE = static_cast(N); auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); for (auto [Field, Comma]: RecTE->Fields) { Ty = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); } N->setType(Ty); return Ty; } case NodeKind::TupleTypeExpression: { auto TupleTE = static_cast(N); std::vector ElementTypes; for (auto [TE, Comma]: TupleTE->Elements) { ElementTypes.push_back(inferTypeExpression(TE, AutoVars)); } auto Ty = new Type(TTuple(ElementTypes)); N->setType(Ty); return Ty; } case NodeKind::NestedTypeExpression: { auto NestedTE = static_cast(N); auto Ty = inferTypeExpression(NestedTE->TE, AutoVars); N->setType(Ty); return Ty; } case NodeKind::ArrowTypeExpression: { auto ArrowTE = static_cast(N); std::vector ParamTypes; for (auto ParamType: ArrowTE->ParamTypes) { ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars)); } auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars); auto Ty = Type::buildArrow(ParamTypes, ReturnType); N->setType(Ty); return Ty; } case NodeKind::QualifiedTypeExpression: { auto QTE = static_cast(N); for (auto [C, Comma]: QTE->Constraints) { inferConstraintExpression(C); } auto Ty = inferTypeExpression(QTE->TE, AutoVars); N->setType(Ty); return Ty; } default: ZEN_UNREACHABLE } } Type* sortRow(Type* Ty) { std::map Fields; while (Ty->isField()) { auto& Field = Ty->asField(); Fields.emplace(Field.Name, Ty); Ty = Field.RestTy; } for (auto [Name, Field]: Fields) { Ty = new Type(TField(Name, Field->asField().Ty, Ty)); } return Ty; } Type* Checker::inferExpression(Expression* X) { Type* Ty; for (auto A: X->Annotations) { if (A->getKind() == NodeKind::TypeAssertAnnotation) { inferTypeExpression(static_cast(A)->TE); } } switch (X->getKind()) { case NodeKind::MatchExpression: { auto Match = static_cast(X); Type* ValTy; if (Match->Value) { ValTy = inferExpression(Match->Value); } else { ValTy = createTypeVar(); } Ty = createTypeVar(); for (auto Case: Match->Cases) { auto OldCtx = &getContext(); setContext(Case->Ctx); auto PattTy = inferPattern(Case->Pattern); makeEqual(PattTy, ValTy, Case); auto ExprTy = inferExpression(Case->Expression); makeEqual(ExprTy, Ty, Case->Expression); setContext(OldCtx); } if (!Match->Value) { Ty = new Type(TArrow(ValTy, Ty)); } break; } case NodeKind::RecordExpression: { auto Record = static_cast(X); Ty = new Type(TNil()); for (auto [Field, Comma]: Record->Fields) { Ty = new Type(TField( getCanonicalText(Field->Name), new Type(TPresent(inferExpression(Field->getExpression()))), Ty )); } Ty = sortRow(Ty); break; } case NodeKind::LiteralExpression: { auto Const = static_cast(X); Ty = inferLiteral(Const->Token); break; } case NodeKind::ReferenceExpression: { auto Ref = static_cast(X); ZEN_ASSERT(Ref->ModulePath.empty()); if (Ref->Name->is()) { auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); if (!Scm) { DE.add(getCanonicalText(Ref->Name), Ref->Name); Ty = createTypeVar(); break; } Ty = instantiate(Scm, X); break; } auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); if (!Target) { DE.add(getCanonicalText(Ref->Name), Ref->Name); Ty = createTypeVar(); break; } if (Target->getKind() == NodeKind::LetDeclaration) { auto Let = static_cast(Target); if (Let->IsCycleActive) { Ty = Let->getType(); break; } if (!Let->Visited) { infer(Let); } } auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); ZEN_ASSERT(Scm); Ty = instantiate(Scm, X); break; } case NodeKind::CallExpression: { auto Call = static_cast(X); auto OpTy = inferExpression(Call->Function); Ty = createTypeVar(); std::vector ArgTypes; for (auto Arg: Call->Args) { ArgTypes.push_back(inferExpression(Arg)); } makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X); break; } case NodeKind::InfixExpression: { auto Infix = static_cast(X); auto Scm = lookup(Infix->Operator->getText(), SymKind::Var); if (Scm == nullptr) { DE.add(Infix->Operator->getText(), Infix->Operator); Ty = createTypeVar(); break; } auto OpTy = instantiate(Scm, Infix->Operator); Ty = createTypeVar(); std::vector ArgTys; ArgTys.push_back(inferExpression(Infix->Left)); ArgTys.push_back(inferExpression(Infix->Right)); makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X); break; } case NodeKind::TupleExpression: { auto Tuple = static_cast(X); std::vector Types; for (auto [E, Comma]: Tuple->Elements) { Types.push_back(inferExpression(E)); } Ty = new Type(TTuple(Types)); break; } case NodeKind::MemberExpression: { auto Member = static_cast(X); auto ExprTy = inferExpression(Member->E); switch (Member->Name->getKind()) { case NodeKind::IntegerLiteral: { auto I = static_cast(Member->Name); Ty = createTypeVar(); addConstraint(new CField(ExprTy, I->asInt(), Ty, Member)); break; } case NodeKind::Identifier: { auto K = static_cast(Member->Name); Ty = createTypeVar(); auto RestTy = createTypeVar(); makeEqual(new Type(TField(getCanonicalText(K), Ty, RestTy)), ExprTy, Member); break; } default: ZEN_UNREACHABLE } break; } case NodeKind::NestedExpression: { auto Nested = static_cast(X); Ty = inferExpression(Nested->Inner); break; } default: ZEN_UNREACHABLE } // Ty = find(Ty); X->setType(Ty); return Ty; } RecordPatternField* getRestField(std::vector> Fields) { for (auto [Field, Comma]: Fields) { if (Field->DotDot) { return Field; } } return nullptr; } Type* Checker::inferPattern( Pattern* Pattern, ConstraintSet* Constraints, TVSet* TVs ) { switch (Pattern->getKind()) { case NodeKind::BindPattern: { auto P = static_cast(Pattern); auto Ty = createTypeVar(); addBinding(getCanonicalText(P->Name), new Forall(TVs, Constraints, Ty), SymKind::Var); return Ty; } case NodeKind::NamedTuplePattern: { auto P = static_cast(Pattern); auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); std::vector ElementTypes; for (auto P2: P->Patterns) { ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); } if (!Scm) { DE.add(getCanonicalText(P->Name), P->Name); return createTypeVar(); } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); makeEqual(Ty, Type::buildArrow(ElementTypes, RetTy), P); return RetTy; } case NodeKind::RecordPattern: { auto P = static_cast(Pattern); auto RestField = getRestField(P->Fields); Type* RecordTy; if (RestField == nullptr) { RecordTy = new Type(TNil()); } else if (RestField->Pattern) { RecordTy = inferPattern(RestField->Pattern); } else { RecordTy = createTypeVar(); } for (auto [Field, Comma]: P->Fields) { if (Field->DotDot) { continue; } Type* FieldTy; if (Field->Pattern) { FieldTy = inferPattern(Field->Pattern, Constraints, TVs); } else { FieldTy = createTypeVar(); addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); } RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); } return RecordTy; } case NodeKind::NamedRecordPattern: { auto P = static_cast(Pattern); auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); if (Scm == nullptr) { DE.add(getCanonicalText(P->Name), P->Name); return createTypeVar(); } auto RestField = getRestField(P->Fields); Type* RecordTy; if (RestField == nullptr) { RecordTy = new Type(TNil()); } else if (RestField->Pattern) { RecordTy = inferPattern(RestField->Pattern); } else { RecordTy = createTypeVar(); } for (auto [Field, Comma]: P->Fields) { if (Field->DotDot) { continue; } Type* FieldTy; if (Field->Pattern) { FieldTy = inferPattern(Field->Pattern, Constraints, TVs); } else { FieldTy = createTypeVar(); addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); } RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); makeEqual(Ty, new Type(TArrow(RecordTy, RetTy)), P); return RetTy; } case NodeKind::TuplePattern: { auto P = static_cast(Pattern); std::vector ElementTypes; for (auto [Element, Comma]: P->Elements) { ElementTypes.push_back(inferPattern(Element)); } return new Type(TTuple(ElementTypes)); } case NodeKind::ListPattern: { auto P = static_cast(Pattern); auto ElementType = createTypeVar(); for (auto [Element, Separator]: P->Elements) { makeEqual(ElementType, inferPattern(Element), P); } return new Type(TApp(ListType, ElementType)); } case NodeKind::NestedPattern: { auto P = static_cast(Pattern); return inferPattern(P->P, Constraints, TVs); } case NodeKind::LiteralPattern: { auto P = static_cast(Pattern); return inferLiteral(P->Literal); } default: ZEN_UNREACHABLE } } Type* Checker::inferLiteral(Literal* L) { Type* Ty; switch (L->getKind()) { case NodeKind::IntegerLiteral: Ty = lookupMono("Int", SymKind::Type); break; case NodeKind::StringLiteral: Ty = lookupMono("String", SymKind::Type); break; default: ZEN_UNREACHABLE } ZEN_ASSERT(Ty != nullptr); return Ty; } void Checker::populate(SourceFile* SF) { struct Visitor : public CSTVisitor { Graph& RefGraph; std::stack Stack; void visitLetDeclaration(LetDeclaration* N) { RefGraph.addVertex(N); Stack.push(N); visitEachChild(N); Stack.pop(); } void visitReferenceExpression(ReferenceExpression* N) { auto Y = static_cast(N); auto Def = Y->getScope()->lookup(Y->getSymbolPath()); // Name lookup failures will be reported directly in inferExpression(). if (Def == nullptr || Def->getKind() != NodeKind::LetDeclaration) { return; } // This case ensures that a deeply nested structure that references a // parameter of a parent node but is not referenced itself is correctly handled. // Note that the edge goes from the parent let to the parameter. This is normal. if (Def->getKind() == NodeKind::Parameter) { RefGraph.addEdge(Stack.top(), Def->Parent); return; } if (!Stack.empty()) { RefGraph.addEdge(Def, Stack.top()); } } }; Visitor V { {}, RefGraph }; V.visit(SF); } Type* Checker::getType(TypedNode *Node) { auto Ty = Node->getType(); if (Node->Flags & NodeFlags_TypeIsSolved) { return Ty; } Ty = solveType(Ty); Node->setType(Ty); Node->Flags |= NodeFlags_TypeIsSolved; return Ty; } void Checker::check(SourceFile *SF) { initialize(SF); setContext(SF->Ctx); addBinding("String", new Forall(StringType), SymKind::Type); addBinding("Int", new Forall(IntType), SymKind::Type); addBinding("Bool", new Forall(BoolType), SymKind::Type); addBinding("List", new Forall(ListType), SymKind::Type); addBinding("True", new Forall(BoolType), SymKind::Var); addBinding("False", new Forall(BoolType), SymKind::Var); auto A = createTypeVar(); addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var); addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); populate(SF); forwardDeclare(SF); auto SCCs = RefGraph.strongconnect(); for (auto Nodes: SCCs) { auto TVs = new TVSet; auto Constraints = new ConstraintSet; for (auto N: Nodes) { if (N->getKind() != NodeKind::LetDeclaration) { continue; } auto Decl = static_cast(N); forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); } } setContext(SF->Ctx); infer(SF); // Important because otherwise some logic for some optimisations will kick in that are no longer active. ActiveContext = nullptr; solve(new CMany(*SF->Ctx->Constraints)); class Visitor : public CSTVisitor { Checker& C; public: Visitor(Checker& C): C(C) {} void visitAnnotation(Annotation* A) { } void visitExpression(Expression* X) { C.getType(X); } } V(*this); V.visit(SF); } void Checker::solve(Constraint* Constraint) { Queue.push_back(Constraint); bool DidJoin = false; std::deque NextQueue; while (true) { if (Queue.empty()) { if (NextQueue.empty() || !DidJoin) { break; } DidJoin = false; std::swap(Queue, NextQueue); } auto Constraint = Queue.front(); Queue.pop_front(); switch (Constraint->getKind()) { case ConstraintKind::Empty: break; case ConstraintKind::Field: { auto Field = static_cast(Constraint); auto MaybeTuple = Field->TupleTy->find(); if (MaybeTuple->isTuple()) { auto& Tuple = MaybeTuple->asTuple(); if (Field->I >= Tuple.ElementTypes.size()) { DE.add(MaybeTuple, Field->I, Field->Source); } else { auto ElementTy = Tuple.ElementTypes[Field->I]; unify(ElementTy, Field->FieldTy, Field->Source); } } else if (MaybeTuple->isVar()) { NextQueue.push_back(Constraint); } else { DE.add(MaybeTuple, Field->Source); } break; } case ConstraintKind::Many: { auto Many = static_cast(Constraint); for (auto Constraint: Many->Elements) { Queue.push_back(Constraint); } break; } case ConstraintKind::Equal: { auto Equal = static_cast(Constraint); if (unify(Equal->Left, Equal->Right, Equal->Source)) { DidJoin = true; } break; } } } } bool assignableTo(Type* A, Type* B) { if (A->isCon() && B->isCon()) { auto& Con1 = A->asCon(); auto& Con2 = B->asCon(); if (Con1.Id != Con2.Id) { return false; } return true; } // TODO must handle a TApp ZEN_UNREACHABLE } class ArrowCursor { /// Types on this stack are guaranteed to be arrow types. std::stack> Stack; TypePath& Path; std::size_t I; public: ArrowCursor(Type* Arr, TypePath& Path): Path(Path) { Stack.push({ Arr, true }); Path.push_back(Arr->getStartIndex()); } Type* next() { while (!Stack.empty()) { auto& [Arrow, First] = Stack.top(); auto& Index = Path.back(); if (!First) { Index.advance(Arrow); } else { First = false; } Type* Ty; if (Index == Arrow->getEndIndex()) { Path.pop_back(); Stack.pop(); continue; } Ty = Arrow->resolve(Index); if (Ty->isArrow()) { auto NewIndex = Arrow->getStartIndex(); Stack.push({ Ty, true }); Path.push_back(NewIndex); } else { return Ty; } } return nullptr; } }; struct Unifier { Checker& C; // CEqual* Constraint; Type* Left; Type* Right; Node* Source; // Internal state used by the unifier ByteString CurrentFieldName; TypePath LeftPath; TypePath RightPath; bool DidJoin = false; Type* getLeft() const { return Left; } Type* getRight() const { return Right; } Node* getSource() const { return Source; } bool unifyField(Type* A, Type* B, bool DidSwap); bool unify(Type* A, Type* B, bool DidSwap); bool unify() { return unify(Left, Right, false); } std::vector findInstanceContext(const TypeSig& Ty, TypeclassId& Class) { auto Match = C.InstanceMap.find(Class); std::vector S; if (Match != C.InstanceMap.end()) { for (auto Instance: Match->second) { if (assignableTo(Ty.Orig, Instance->TypeExps[0]->getType())) { std::vector S; for (auto Arg: Ty.Args) { TypeclassContext Classes; // TODO S.push_back(Classes); } return S; } } } C.DE.add(Class, Ty.Orig, getSource()); for (auto Arg: Ty.Args) { S.push_back({}); } return S; } TypeSig getTypeSig(Type* Ty) { Type* Op = nullptr; std::vector Args; std::function Visit = [&](Type* Ty) { if (Ty->isApp()) { Visit(Ty->asApp().Op); Visit(Ty->asApp().Arg); } else if (!Op) { Op = Ty; } else { Args.push_back(Ty); } }; Visit(Ty); return TypeSig { Ty, Op, Args }; } void propagateClasses(std::unordered_set& Classes, Type* Ty) { if (Ty->isVar()) { auto TV = Ty->asVar(); for (auto Class: Classes) { TV.Context.emplace(Class); } if (TV.isRigid()) { for (auto Id: TV.Context) { if (!TV.Provided->count(Id)) { C.DE.add(TypeclassSignature { Id, { Ty } }, getSource()); } } } } else if (Ty->isCon() || Ty->isApp()) { auto Sig = getTypeSig(Ty); for (auto Class: Classes) { propagateClassTycon(Class, Sig); } } else if (!Classes.empty()) { C.DE.add(Ty, std::vector(Classes.begin(), Classes.end()), getSource()); } }; void propagateClassTycon(TypeclassId& Class, const TypeSig& Sig) { auto S = findInstanceContext(Sig, Class); for (auto [Classes, Arg]: zen::zip(S, Sig.Args)) { propagateClasses(Classes, Arg); } }; /** * Assign a type to a unification variable. * * If there are class constraints, those are propagated. * * If this type variable is solved during inference, it will be removed from * the inference context. * * Other side effects may occur. */ void join(Type* TV, Type* Ty) { // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; TV->set(Ty); DidJoin = true; propagateClasses(TV->asVar().Context, Ty); // This is a very specific adjustment that is critical to the // well-functioning of the infer/unify algorithm. When addConstraint() is // called, it may decide to solve the constraint immediately during // inference. If this happens, a type variable might get assigned a concrete // type such as Int. We therefore never want the variable to be polymorphic // and be instantiated with a fresh variable, as that would allow Bool to // collide with Int. // // Should it get assigned another unification variable, that's OK too // because then that variable is what matters and it will become the new // (possibly polymorphic) variable. if (C.ActiveContext) { // std::cerr << "erase " << describe(TV) << std::endl; auto TVs = C.ActiveContext->TVs; TVs->erase(TV); } } }; bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { if (A->isAbsent() && B->isAbsent()) { return true; } if (B->isAbsent()) { std::swap(A, B); DidSwap = !DidSwap; } if (A->isAbsent()) { auto& Present = B->asPresent(); C.DE.add(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource()); return false; } auto& Present1 = A->asPresent(); auto& Present2 = B->asPresent(); return unify(Present1.Ty, Present2.Ty, DidSwap); }; bool Unifier::unify(Type* A, Type* B, bool DidSwap) { A = A->find(); B = B->find(); auto unifyError = [&]() { C.DE.add( Left, Right, LeftPath, RightPath, Source ); }; auto pushLeft = [&](TypeIndex I) { if (DidSwap) { RightPath.push_back(I); } else { LeftPath.push_back(I); } }; auto popLeft = [&]() { if (DidSwap) { RightPath.pop_back(); } else { LeftPath.pop_back(); } }; auto pushRight = [&](TypeIndex I) { if (DidSwap) { LeftPath.push_back(I); } else { RightPath.push_back(I); } }; auto popRight = [&]() { if (DidSwap) { LeftPath.pop_back(); } else { RightPath.pop_back(); } }; auto swap = [&]() { std::swap(A, B); DidSwap = !DidSwap; }; if (A->isVar() && B->isVar()) { auto& Var1 = A->asVar(); auto& Var2 = B->asVar(); if (Var1.isRigid() && Var2.isRigid()) { if (Var1.Id != Var2.Id) { unifyError(); return false; } return true; } Type* To; Type* From; if (Var1.isRigid() && Var2.isUni()) { To = A; From = B; } else { // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification // Either way, Var1, being Unification, is a good candidate for being unified away To = B; From = A; } if (From->asVar().Id != To->asVar().Id) { join(From, To); } return true; } if (B->isVar()) { swap(); } if (A->isVar()) { auto& TV = A->asVar(); // Rigid type variables can never unify with antything else than what we // have already handled in the previous if-statement, so issue an error. if (TV.isRigid()) { unifyError(); return false; } // Occurs check if (B->hasTypeVar(A)) { // NOTE Just like GHC, we just display an error message indicating that // A cannot match B, e.g. a cannot match [a]. It looks much better // than obsure references to an occurs check unifyError(); return false; } join(A, B); return true; } if (A->isArrow() && B->isArrow()) { auto& Arrow1 = A->asArrow(); auto& Arrow2 = B->asArrow(); bool Success = true; LeftPath.push_back(TypeIndex::forArrowParamType()); RightPath.push_back(TypeIndex::forArrowParamType()); if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forArrowReturnType()); RightPath.push_back(TypeIndex::forArrowReturnType()); if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); return Success; } if (A->isApp() && B->isApp()) { auto& App1 = A->asApp(); auto& App2 = B->asApp(); bool Success = true; LeftPath.push_back(TypeIndex::forAppOpType()); RightPath.push_back(TypeIndex::forAppOpType()); if (!unify(App1.Op, App2.Op, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forAppArgType()); RightPath.push_back(TypeIndex::forAppArgType()); if (!unify(App1.Arg, App2.Arg, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); return Success; } if (A->isTuple() && B->isTuple()) { auto& Tuple1 = A->asTuple(); auto& Tuple2 = B->asTuple(); if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) { unifyError(); return false; } auto Count = Tuple1.ElementTypes.size(); bool Success = true; for (size_t I = 0; I < Count; I++) { LeftPath.push_back(TypeIndex::forTupleElement(I)); RightPath.push_back(TypeIndex::forTupleElement(I)); if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); } return Success; } // if (A->isTupleIndex() || B->isTupleIndex()) { // // Type(s) could not be simplified at the beginning of this function, // // so we have to re-visit the constraint when there is more information. // C.Queue.push_back(Constraint); // return true; // } // This does not work because it ignores the indices // if (A->isTupleIndex() && B->isTupleIndex()) { // auto Index1 = static_cast(A); // auto Index2 = static_cast(B); // return unify(Index1->Ty, Index2->Ty, Source); // } if (A->isCon() && B->isCon()) { auto& Con1 = A->asCon(); auto& Con2 = B->asCon(); if (Con1.Id != Con2.Id) { unifyError(); return false; } return true; } if (A->isNil() && B->isNil()) { return true; } if (A->isField() && B->isField()) { auto& Field1 = A->asField(); auto& Field2 = B->asField(); bool Success = true; if (Field1.Name == Field2.Name) { LeftPath.push_back(TypeIndex::forFieldType()); RightPath.push_back(TypeIndex::forFieldType()); CurrentFieldName = Field1.Name; if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forFieldRest()); RightPath.push_back(TypeIndex::forFieldRest()); if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); return Success; } auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++)); pushLeft(TypeIndex::forFieldRest()); if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) { Success = false; } popLeft(); pushRight(TypeIndex::forFieldRest()); if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) { Success = false; } popRight(); return Success; } if (A->isNil() && B->isField()) { swap(); } if (A->isField() && B->isNil()) { auto& Field = A->asField(); bool Success = true; pushLeft(TypeIndex::forFieldType()); CurrentFieldName = Field.Name; if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) { Success = false; } popLeft(); pushLeft(TypeIndex::forFieldRest()); if (!unify(Field.RestTy, B, DidSwap)) { Success = false; } popLeft(); return Success; } unifyError(); return false; } bool Checker::unify(Type* Left, Type* Right, Node* Source) { // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; Unifier A { *this, Left, Right, Source }; A.unify(); return A.DidJoin; } }