From 693301bac6ea7bd84beacc37c880c239eb5b16b2 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Thu, 1 Jun 2023 16:22:37 +0200 Subject: [PATCH] Improve support for type classes and simplify algorithm --- include/bolt/CST.hpp | 4 + include/bolt/Diagnostics.hpp | 8 +- include/bolt/Type.hpp | 37 +++-- src/Checker.cc | 287 ++++++++++++----------------------- src/Diagnostics.cc | 2 + src/Parser.cc | 10 +- 6 files changed, 138 insertions(+), 210 deletions(-) diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 1c70ce2f1..0b2efc6b0 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -1765,6 +1765,10 @@ namespace bolt { return Parent->getKind() == NodeKind::ClassDeclaration; } + bool isSignature() const noexcept { + return !Body; + } + Token* getFirstToken() const override; Token* getLastToken() const override; diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index 0ba01ead0..888913636 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -109,9 +109,9 @@ namespace bolt { public: TypeclassSignature Sig; - FunctionDeclaration* Decl; + Node* Decl; - inline TypeclassMissingDiagnostic(TypeclassSignature Sig, FunctionDeclaration* Decl): + inline TypeclassMissingDiagnostic(TypeclassSignature Sig, Node* Decl): Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} inline Node* getNode() const override { @@ -124,10 +124,10 @@ namespace bolt { public: ByteString TypeclassName; - TCon* Ty; + Type* Ty; Node* Source; - inline InstanceNotFoundDiagnostic(ByteString TypeclassName, TCon* Ty, Node* Source): + inline InstanceNotFoundDiagnostic(ByteString TypeclassName, Type* Ty, Node* Source): Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {} inline Node* getNode() const override { diff --git a/include/bolt/Type.hpp b/include/bolt/Type.hpp index cc7f8c257..67c1b756f 100644 --- a/include/bolt/Type.hpp +++ b/include/bolt/Type.hpp @@ -13,6 +13,7 @@ namespace bolt { class Type; class TVar; + class TCon; using TypeclassId = ByteString; @@ -29,6 +30,12 @@ namespace bolt { }; + struct TypeSig { + Type* Orig; + Type* Op; + std::vector Args; + }; + enum class TypeIndexKind { AppOpType, AppArgType, @@ -267,6 +274,10 @@ namespace bolt { return VK; } + inline bool isRigid() const noexcept { + return VK == VarKind::Rigid; + } + Type* find(); void set(Type* Ty); @@ -282,6 +293,8 @@ namespace bolt { ByteString Name; + TypeclassContext Provided; + inline TVarRigid(size_t Id, ByteString Name): TVar(Id, VarKind::Rigid), Name(Name) {} @@ -405,44 +418,48 @@ namespace bolt { virtual void enterType(C* Ty) {} virtual void exitType(C* Ty) {} - virtual void visitVarType(C* Ty) { + virtual void visitType(C* Ty) { visitEachChild(Ty); } + virtual void visitVarType(C* Ty) { + visitType(Ty); + } + virtual void visitAppType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitPresentType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitConType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitArrowType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitTupleType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitTupleIndexType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitAbsentType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitFieldType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } virtual void visitNilType(C* Ty) { - visitEachChild(Ty); + visitType(Ty); } public: diff --git a/src/Checker.cc b/src/Checker.cc index 76c120992..cbb4bd3ea 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -20,6 +20,8 @@ // TODO create the constraint in addConstraint, not the other way round +// TODO Find a way to create a match expression in between ( and ) + #include #include #include @@ -271,11 +273,11 @@ namespace bolt { case NodeKind::ClassDeclaration: { auto Class = static_cast(X); - for (auto TE: Class->TypeVars) { - auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); - TV->Contexts.emplace(Class->Name->getCanonicalText()); - TE->setType(TV); - } + // for (auto TE: Class->TypeVars) { + // auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); + // // TV->Contexts.emplace(Class->Name->getCanonicalText()); + // TE->setType(TV); + // } for (auto Element: Class->Elements) { forwardDeclare(Element); } @@ -459,17 +461,26 @@ namespace bolt { setContext(Let->Ctx); + auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { + auto Id = Class->Name->getCanonicalText(); + auto Ctx = &getContext(); + std::vector Out; + for (auto TE: Class->TypeVars) { + auto Name = TE->Name->getCanonicalText(); + auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); + TV->Contexts.emplace(Id); + Ctx->Env.emplace(Name, new Forall(TV)); + Out.push_back(TV); + } + return Out; + }; + // If declaring a let-declaration inside a type class declaration, // we need to mark that the let-declaration requires this class. // This marking is set on the rigid type variables of the class, which // are then added to this local type environment. if (Let->isClass()) { - auto Class = static_cast(Let->Parent); - for (auto TE: Class->TypeVars) { - auto TV = llvm::cast(TE->getType()); - Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); - Let->Ctx->TVs->emplace(TV); - } + addClassVars(static_cast(Let->Parent), true); } // Here we infer the primary type of the let declaration. If there's a @@ -492,27 +503,22 @@ namespace bolt { auto Instance = static_cast(Let->Parent); auto Class = llvm::cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); + auto SigLet = llvm::cast(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var)); + + auto Params = addClassVars(Class, false); // The type asserts in the type class declaration might make use of // the type parameters of the type class declaration, so it is // important to make them available in the type environment. Moreover, // we will be unifying them with the actual types declared in the // instance declaration, so we keep track of them. - std::vector Params; - TVSub Sub; - for (auto TE: Class->TypeVars) { - auto TV = createTypeVar(); - Sub.emplace(llvm::cast(TE->getType()), TV); - Params.push_back(TV); - } - - auto SigLet = llvm::cast(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var)); - - // It would be very strange if there was no type assert in the type - // class let-declaration but we rather not let the compiler crash if that happens. - if (SigLet->TypeAssert) { - addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression)->substitute(Sub), Let)); - } + // std::vector Params; + // TVSub Sub; + // for (auto TE: Class->TypeVars) { + // auto TV = createTypeVar(); + // Sub.emplace(llvm::cast(TE->getType()), TV); + // Params.push_back(TV); + // } // Here we do the actual unification of e.g. Eq a with Eq Bool. The // unification variables we created previously will be unified with @@ -522,6 +528,16 @@ namespace bolt { addConstraint(new CEqual(Param, TE->getType(), TE)); } + // It would be very strange if there was no type assert in the type + // class let-declaration but we rather not let the compiler crash if that happens. + if (SigLet->TypeAssert) { + // Note that we can't do SigLet->TypeAssert->TypeExpression->getType() + // because we need to re-generate the type within the local context of + // this let-declaration. + // TODO make CEqual accept multiple nodes + addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let)); + } + } if (Let->Body) { @@ -542,12 +558,18 @@ namespace bolt { } } - Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty)); + if (!Let->isInstance()) { + Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty)); + } } void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) { + if (Decl->isSignature()) { + return; + } + setContext(Decl->Ctx); std::vector ParamTypes; @@ -770,7 +792,9 @@ namespace bolt { auto D = static_cast(C); std::vector Types; for (auto TE: D->TEs) { - Types.push_back(inferTypeExpression(TE)); + auto TV = static_cast(inferTypeExpression(TE)); + TV->Provided.emplace(D->Name->getCanonicalText()); + Types.push_back(TV); } return new CClass(D->Name->getCanonicalText(), Types); } @@ -1175,136 +1199,6 @@ namespace bolt { } - void Checker::checkTypeclassSigs(Node* N) { - - struct LetVisitor : CSTVisitor { - - Checker& C; - - void visitLetDeclaration(FunctionDeclaration* Decl) { - - // Only inspect those let-declarations that look like a function - if (Decl->Params.empty()) { - return; - } - - // Will contain the type classes that were specified in the type assertion by the user. - // There might be some other signatures as well, but those are an implementation detail. - std::vector Expected; - - // We must add the type class itself to Expected because in order for - // propagation to work the rigid type variables expect this class to be - // present even inside the current class. By adding it to Expected, we - // are effectively cancelling out the default behavior of requiring the - // presence of this type classes. - if (llvm::isa(Decl->Parent)) { - auto Class = llvm::cast(Decl->Parent); - std::vector Tys; - for (auto TE : Class->TypeVars) { - Tys.push_back(llvm::cast(TE->getType())); - } - Expected.push_back( - TypeclassSignature{Class->Name->getCanonicalText(), Tys}); - } - - // Here we scan the type signature for type classes that user expects to be there. - if (Decl->TypeAssert != nullptr) { - if (llvm::isa(Decl->TypeAssert->TypeExpression)) { - auto QTE = static_cast(Decl->TypeAssert->TypeExpression); - for (auto [C, Comma]: QTE->Constraints) { - if (llvm::isa(C)) { - auto TCE = static_cast(C); - std::vector Tys; - for (auto TE: TCE->TEs) { - auto TV = TE->getType(); - ZEN_ASSERT(llvm::isa(TV)); - Tys.push_back(static_cast(TV)); - } - Expected.push_back(TypeclassSignature { TCE->Name->getCanonicalText(), Tys }); - } - } - } - } - - // Sort them lexically and remove any duplicates - std::sort(Expected.begin(), Expected.end()); - Expected.erase(std::unique(Expected.begin(), Expected.end()), Expected.end()); - - // Will contain the type class signatures that our program inferred that - // at the very least should be present to make the body work. - std::vector Actual; - - // This is ugly but it works. Scan all type variables local to this - // declaration and add the classes that they require to Actual. - for (auto Ty: *Decl->Ctx->TVs) { - auto S = Ty->solve(); - if (llvm::isa(S)) { - auto TV = static_cast(S); - for (auto Class: TV->Contexts) { - Actual.push_back(TypeclassSignature { Class, { TV } }); - } - } - } - - // Sort them lexically and remove any duplicates - std::sort(Actual.begin(), Actual.end()); - Actual.erase(std::unique(Actual.begin(), Actual.end()), Actual.end()); - - auto ActualIter = Actual.begin(); - auto ExpectedIter = Expected.begin(); - - for (; ActualIter != Actual.end() || ExpectedIter != Expected.end() ;) { - - // Our program inferred no more type classes that should be present, - // yet Expected still did find a few that the user declared in a - // signature. No errors should be reported, and we can quit this loop. - if (ActualIter == Actual.end()) { - // TODO Maybe issue a warning that a type class went unused - break; - } - - // There are no more type classes that were expected, so any remaining - // type classes in Actual will not have a corresponding signature. - // This should be reported as an error. - if (ExpectedIter == Expected.end()) { - for (; ActualIter != Actual.end(); ActualIter++) { - C.DE.add(*ActualIter, Decl); - } - break; - } - - // If ExpectedIter is already at Show, but ActualIter is still at Eq, - // then we clearly missed the Eq in ExpectedIter. This clearly is an - // error, since the user missed something in a type signature. - if (*ActualIter < *ExpectedIter) { - C.DE.add(*ActualIter, Decl); - ActualIter++; - continue; - } - - // If ActualIter is Show but ExpectedIter is still Eq, then the user - // specified too much type classes in a type signature. This is no error, - // but it might be worthwhile to issue a warning. - if (*ExpectedIter < *ActualIter) { - // DE.add(It2->Name, Decl); - ExpectedIter++; - continue; - } - - // Both type class signatures are equal, cancelling each other out. - ActualIter++; - ExpectedIter++; - } - - } - - }; - - LetVisitor V { {}, *this }; - V.visit(N); - - } - Type* Checker::getType(TypedNode *Node) { return Node->getType()->solve(); } @@ -1368,7 +1262,6 @@ namespace bolt { ActiveContext = nullptr; solve(new CMany(*SF->Ctx->Constraints)); - checkTypeclassSigs(SF); } void Checker::solve(Constraint* Constraint) { @@ -1419,15 +1312,9 @@ namespace bolt { if (Con1->Id != Con2-> Id) { return false; } - // TODO must handle a TApp - // ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); - // for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) { - // if (!assignableTo(T1, T2)) { - // return false; - // } - // } return true; } + // TODO must handle a TApp ZEN_UNREACHABLE } @@ -1504,52 +1391,78 @@ namespace bolt { return unify(Constraint->Left, Constraint->Right, false); } - std::vector findInstanceContext(TCon* Ty, TypeclassId& Class) { + std::vector findInstanceContext(const TypeSig& Ty, TypeclassId& Class) { auto Match = C.InstanceMap.find(Class); std::vector S; if (Match != C.InstanceMap.end()) { for (auto Instance: Match->second) { - if (assignableTo(Ty, Instance->TypeExps[0]->getType())) { + if (assignableTo(Ty.Orig, Instance->TypeExps[0]->getType())) { std::vector S; - // TODO handle TApp - // for (auto Arg: Ty->Args) { - // TypeclassContext Classes; - // // TODO - // S.push_back(Classes); - // } + for (auto Arg: Ty.Args) { + TypeclassContext Classes; + // TODO + S.push_back(Classes); + } return S; } } } - C.DE.add(Class, Ty, getSource()); - // TODO handle TApp - // for (auto Arg: Ty->Args) { - // S.push_back({}); - // } + C.DE.add(Class, Ty.Orig, getSource()); + for (auto Arg: Ty.Args) { + S.push_back({}); + } return S; } + TypeSig getTypeSig(Type* Ty) { + struct Visitor : TypeVisitor { + Type* Op = nullptr; + std::vector Args; + void visitType(Type* Ty) override { + if (!Op) { + Op = Ty; + } else { + Args.push_back(Ty); + } + } + void visitAppType(TApp* Ty) override { + visitEachChild(Ty); + } + }; + Visitor V; + V.visit(Ty); + return TypeSig { Ty, V.Op, V.Args }; + } + void propagateClasses(std::unordered_set& Classes, Type* Ty) { if (llvm::isa(Ty)) { auto TV = llvm::cast(Ty); for (auto Class: Classes) { TV->Contexts.emplace(Class); } - } else if (llvm::isa(Ty)) { + if (TV->isRigid()) { + auto RV = static_cast(Ty); + for (auto Id: RV->Contexts) { + if (!RV->Provided.count(Id)) { + C.DE.add(TypeclassSignature { Id, { RV } }, getSource()); + } + } + } + } else if (llvm::isa(Ty) || llvm::isa(Ty)) { + auto Sig = getTypeSig(Ty); for (auto Class: Classes) { - propagateClassTycon(Class, llvm::cast(Ty)); + propagateClassTycon(Class, Sig); } } else if (!Classes.empty()) { C.DE.add(Ty, std::vector(Classes.begin(), Classes.end()), getSource()); } }; - void propagateClassTycon(TypeclassId& Class, TCon* Ty) { - auto S = findInstanceContext(Ty, Class); - // TODO handle TApp - // for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { - // propagateClasses(Classes, Arg); - // } + void propagateClassTycon(TypeclassId& Class, const TypeSig& Sig) { + auto S = findInstanceContext(Sig, Class); + for (auto [Classes, Arg]: zen::zip(S, Sig.Args)) { + propagateClasses(Classes, Arg); + } }; /** diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index ae83291da..b9a975d3a 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -110,6 +110,8 @@ namespace bolt { return "'return'"; case NodeKind::TypeKeyword: return "'type'"; + case NodeKind::ReferenceTypeExpression: + return "a type reference"; case NodeKind::FunctionDeclaration: return "a function declaration"; case NodeKind::VariableDeclaration: diff --git a/src/Parser.cc b/src/Parser.cc index f41c107d2..2e57f8b16 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -1029,7 +1029,6 @@ finish: PubKeyword* Pub = nullptr; FnKeyword* Fn; - MutKeyword* Mut = nullptr; TypeAssert* TA = nullptr; LetBody* Body = nullptr; @@ -1047,11 +1046,6 @@ finish: return nullptr; } Fn = static_cast(T0); - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::MutKeyword) { - Mut = static_cast(T1); - Tokens.get(); - } auto Name = expectToken(); if (!Name) { @@ -1059,9 +1053,6 @@ finish: Pub->unref(); } Fn->unref(); - if (Mut) { - Mut->unref(); - } skipToLineFoldEnd(); return nullptr; } @@ -1146,6 +1137,7 @@ after_params: checkLineFoldEnd(); finish: + return new FunctionDeclaration( Pub, Fn,