From a8f8658f27dc29bf7bc092516d81978873253764 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Tue, 30 May 2023 13:37:47 +0200 Subject: [PATCH] Make InferContext have a parent context --- include/bolt/CST.hpp | 27 ++++- include/bolt/Checker.hpp | 39 ++++--- src/Checker.cc | 223 ++++++++++++++++++++++++--------------- 3 files changed, 180 insertions(+), 109 deletions(-) diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 36da0658b..8bf36c8bd 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -17,6 +17,7 @@ namespace bolt { class Type; + class InferContext; class Token; class SourceFile; @@ -1279,6 +1280,8 @@ namespace bolt { class MatchCase : public Node { public: + InferContext* Ctx; + class Pattern* Pattern; class RArrowAlt* RArrowAlt; class Expression* Expression; @@ -1658,9 +1661,6 @@ namespace bolt { }; - class Type; - class InferContext; - class LetDeclaration : public Node { Scope* TheScope = nullptr; @@ -1703,6 +1703,22 @@ namespace bolt { return TheScope; } + bool isFunc() const noexcept { + return !Params.empty(); + } + + bool isVar() const noexcept { + return !isFunc(); + } + + bool isInstance() const noexcept { + return Parent->getKind() == NodeKind::InstanceDeclaration; + } + + bool isClass() const noexcept { + return Parent->getKind() == NodeKind::ClassDeclaration; + } + Token* getFirstToken() const override; Token* getLastToken() const override; @@ -1801,6 +1817,8 @@ namespace bolt { class RecordDeclaration : public Node { public: + InferContext* Ctx; + class PubKeyword* PubKeyword; class StructKeyword* StructKeyword; IdentifierAlt* Name; @@ -1878,6 +1896,8 @@ namespace bolt { class VariantDeclaration : public Node { public: + InferContext* Ctx; + class PubKeyword* PubKeyword; class EnumKeyword* EnumKeyword; class IdentifierAlt* Name; @@ -1912,6 +1932,7 @@ namespace bolt { public: TextFile& File; + InferContext* Ctx; std::vector Elements; diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index d4da9385a..fe4a13eb6 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -157,10 +157,8 @@ namespace bolt { TypeEnv Env; Type* ReturnType = nullptr; - std::vector Classes; - //inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType): - // Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {} + InferContext* Parent = nullptr; }; @@ -183,22 +181,20 @@ namespace bolt { std::unordered_map> InstanceMap; - std::vector Contexts; + // std::vector Contexts; + + InferContext* ActiveContext; + + InferContext& getContext(); + void pushContext(InferContext* Ctx); + void popContext(); /** * The queue that is used during solving to store any unsolved constraints. */ std::deque Queue; - /** - * Pointer to the current constraint being unified. - */ - CEqual* C; - - InferContext& getContext(); - void addConstraint(Constraint* Constraint); - void addClass(TypeclassSignature Sig); void forwardDeclare(Node* Node); void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); @@ -217,12 +213,18 @@ namespace bolt { TCon* createConType(ByteString Name); TVar* createTypeVar(); TVarRigid* createRigidVar(ByteString Name); - InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet); + InferContext* createInferContext( + InferContext* Parent = nullptr, + TVSet* TVs = new TVSet, + ConstraintSet* Constraints = new ConstraintSet + ); void addBinding(ByteString Name, Scheme* Scm); Scheme* lookup(ByteString Name); + void initialize(Node* N); + /** * Looks up a type/variable and ensures that it is a monomorphic type. * @@ -248,6 +250,9 @@ namespace bolt { void propagateClasses(TypeclassContext& Classes, Type* Ty); void propagateClassTycon(TypeclassId& Class, TCon* Ty); + // TODO Remove this + Node* Source; + /** * Assign a type to a unification variable. * @@ -260,14 +265,6 @@ namespace bolt { */ void join(TVar* A, Type* B); - // Unification parameters - Type* OrigLeft; - Type* OrigRight; - TypePath LeftPath; - TypePath RightPath; - ByteString CurrentFieldName; - Node* Source; - bool unify(Type* A, Type* B); void unifyError(); diff --git a/src/Checker.cc b/src/Checker.cc index 4e789bc03..594753c2f 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -98,12 +98,16 @@ namespace bolt { } Scheme* Checker::lookup(ByteString Name) { - for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) { - auto Curr = *Iter; + auto Curr = &getContext(); + for (;;) { auto Match = Curr->Env.find(Name); if (Match != Curr->Env.end()) { return Match->second; } + Curr = Curr->Parent; + if (!Curr) { + break; + } } return nullptr; } @@ -119,11 +123,11 @@ namespace bolt { } void Checker::addBinding(ByteString Name, Scheme* Scm) { - Contexts.back()->Env.emplace(Name, Scm); + getContext().Env.emplace(Name, Scm); } Type* Checker::getReturnType() { - auto Ty = Contexts.back()->ReturnType; + auto Ty = getContext().ReturnType; ZEN_ASSERT(Ty != nullptr); return Ty; } @@ -137,24 +141,43 @@ namespace bolt { return false; } + void Checker::pushContext(InferContext* Ctx) { + ActiveContext = Ctx; + } + + void Checker::popContext() { + ZEN_ASSERT(ActiveContext); + ActiveContext = ActiveContext->Parent; + } + InferContext& Checker::getContext() { - ZEN_ASSERT(!Contexts.empty()); - return *Contexts.back(); + ZEN_ASSERT(ActiveContext); + return *ActiveContext; } void Checker::addConstraint(Constraint* C) { switch (C->getKind()) { case ConstraintKind::Class: { - Contexts.back()->Constraints->push_back(C); + getContext().Constraints->push_back(C); break; } case ConstraintKind::Equal: { auto Y = static_cast(C); + auto Curr = &getContext(); + std::vector Contexts; + for (;;) { + Contexts.push_back(Curr); + Curr = Curr->Parent; + if (!Curr) { + break; + } + } + std::size_t MaxLevelLeft = 0; - for (std::size_t I = Contexts.size(); I-- > 0; ) { + for (std::size_t I = 0; I < Contexts.size(); I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left)) { MaxLevelLeft = I; @@ -162,7 +185,7 @@ namespace bolt { } } std::size_t MaxLevelRight = 0; - for (std::size_t I = Contexts.size(); I-- > 0; ) { + for (std::size_t I = 0; I < Contexts.size(); I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Right)) { MaxLevelRight = I; @@ -172,7 +195,7 @@ namespace bolt { auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight); std::size_t MinLevel = MaxLevel; - for (std::size_t I = 0; I < Contexts.size(); I++) { + for (std::size_t I = Contexts.size(); I-- > 0; ) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { MinLevel = I; @@ -180,7 +203,6 @@ namespace bolt { } } - // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0 if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) { solveCEqual(Y); } else { @@ -202,10 +224,6 @@ namespace bolt { } } - void Checker::addClass(TypeclassSignature Sig) { - getContext().Classes.push_back(Sig); - } - void Checker::forwardDeclare(Node* X) { switch (X->getKind()) { @@ -269,21 +287,19 @@ namespace bolt { { auto Decl = static_cast(X); - auto& ParentCtx = getContext(); - auto Ctx = createInferContext(); - Contexts.push_back(Ctx); + pushContext(Decl->Ctx); std::vector Vars; for (auto TE: Decl->TVs) { auto TV = createRigidVar(TE->Name->getCanonicalText()); - Ctx->TVs->emplace(TV); + Decl->Ctx->TVs->emplace(TV); Vars.push_back(TV); } Type* Ty = createConType(Decl->Name->getCanonicalText()); // Must be added early so we can create recursive types - ParentCtx.Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty)); + Decl->Ctx->Parent->Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty)); for (auto Member: Decl->Members) { switch (Member->getKind()) { @@ -298,7 +314,7 @@ namespace bolt { for (auto Element: TupleMember->Elements) { ParamTypes.push_back(inferTypeExpression(Element)); } - ParentCtx.Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Ctx->TVs, Ctx->Constraints, new TArrow(ParamTypes, RetTy))); + Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy))); break; } case NodeKind::RecordVariantDeclarationMember: @@ -311,7 +327,7 @@ namespace bolt { } } - Contexts.pop_back(); + popContext(); break; } @@ -320,13 +336,12 @@ namespace bolt { { auto Decl = static_cast(X); - auto& ParentCtx = getContext(); - auto Ctx = createInferContext(); - Contexts.push_back(Ctx); + pushContext(Decl->Ctx); + std::vector Vars; for (auto TE: Decl->Vars) { auto TV = createRigidVar(TE->Name->getCanonicalText()); - Ctx->TVs->emplace(TV); + Decl->Ctx->TVs->emplace(TV); Vars.push_back(TV); } @@ -334,9 +349,9 @@ namespace bolt { auto Ty = createConType(Name); // Must be added early so we can create recursive types - ParentCtx.Env.emplace(Name, new Forall(Ty)); + Decl->Ctx->Parent->Env.emplace(Name, new Forall(Ty)); - // Corresponds to the logic of one branch of a VaraintDeclarationMember + // Corresponds to the logic of one branch of a VariantDeclarationMember Type* FieldsTy = new TNil(); for (auto Field: Decl->Fields) { FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy); @@ -345,8 +360,8 @@ namespace bolt { for (auto TV: Vars) { RetTy = new TApp(RetTy, TV); } - Contexts.pop_back(); - addBinding(Name, new Forall(Ctx->TVs, Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); + popContext(); + addBinding(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); break; } @@ -358,24 +373,70 @@ namespace bolt { } + void Checker::initialize(Node* N) { + + struct Init : public CSTVisitor { + + Checker& C; + + std::stack Contexts; + + InferContext* createDerivedContext() { + return C.createInferContext(Contexts.top()); + } + + void visitVariantDeclaration(VariantDeclaration* Decl) { + Decl->Ctx = createDerivedContext(); + } + + void visitRecordDeclaration(RecordDeclaration* Decl) { + Decl->Ctx = createDerivedContext(); + } + + void visitMatchCase(MatchCase* C) { + C->Ctx = createDerivedContext(); + Contexts.push(C->Ctx); + visitEachChild(C); + Contexts.pop(); + } + + void visitSourceFile(SourceFile* SF) { + SF->Ctx = C.createInferContext(); + Contexts.push(SF->Ctx); + visitEachChild(SF); + Contexts.pop(); + } + + void visitLetDeclaration(LetDeclaration* Let) { + if (Let->isFunc()) { + Let->Ctx = createDerivedContext(); + Contexts.push(Let->Ctx); + visitEachChild(Let); + Contexts.pop(); + } else { + Let->Ctx = Contexts.top(); + visitEachChild(Let); + } + } + + }; + + Init I { {}, *this }; + I.visit(N); + + } + void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) { auto Let = static_cast(N); - bool IsFunc = !Let->Params.empty(); - bool IsInstance = llvm::isa(Let->Parent); - bool IsClass = llvm::isa(Let->Parent); - bool HasContext = IsFunc || IsInstance || IsClass; - if (HasContext) { - Let->Ctx = createInferContext(TVs, Constraints); - Contexts.push_back(Let->Ctx); - } + pushContext(Let->Ctx); // If declaring a let-declaration inside a type class declaration, // we need to mark that the let-declaration requires this class. // This marking is set on the rigid type variables of the class, which // are then added to this local type environment. - if (IsClass) { + if (Let->isClass()) { auto Class = static_cast(Let->Parent); for (auto TE: Class->TypeVars) { auto TV = llvm::cast(TE->getType()); @@ -400,7 +461,7 @@ namespace bolt { // we need to perform some work to make sure the type asserts of the // corresponding let-declaration in the type class declaration are // accounted for. - if (IsInstance) { + if (Let->isInstance()) { auto Instance = static_cast(Let->Parent); auto Class = llvm::cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); @@ -443,7 +504,7 @@ namespace bolt { case NodeKind::LetBlockBody: { auto Block = static_cast(Let->Body); - if (IsFunc) { + if (Let->isFunc()) { Let->Ctx->ReturnType = createTypeVar(); } for (auto Element: Block->Elements) { @@ -456,9 +517,10 @@ namespace bolt { } } + popContext(); + Type* BindTy; - if (HasContext) { - Contexts.pop_back(); + if (Let->isFunc()) { BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs); } else { BindTy = inferPattern(Let->Pattern); @@ -467,17 +529,9 @@ namespace bolt { } - void Checker::inferLetDeclaration(LetDeclaration* N) { + void Checker::inferLetDeclaration(LetDeclaration* Decl) { - auto Decl = static_cast(N); - bool IsFunc = !Decl->Params.empty(); - bool IsInstance = llvm::isa(Decl->Parent); - bool IsClass = llvm::isa(Decl->Parent); - bool HasContext = IsFunc || IsInstance || IsClass; - - if (HasContext) { - Contexts.push_back(Decl->Ctx); - } + pushContext(Decl->Ctx); std::vector ParamTypes; Type* RetType; @@ -498,7 +552,7 @@ namespace bolt { case NodeKind::LetBlockBody: { auto Block = static_cast(Decl->Body); - ZEN_ASSERT(HasContext); + ZEN_ASSERT(Decl->isFunc()); RetType = Decl->Ctx->ReturnType; for (auto Element: Block->Elements) { infer(Element); @@ -512,15 +566,13 @@ namespace bolt { RetType = createTypeVar(); } - if (HasContext) { - Contexts.pop_back(); - } + popContext(); - if (IsFunc) { - addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N }); + if (Decl->isFunc()) { + addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl }); } else { // Declaration is a plain (typed) variable - addConstraint(new CEqual { Decl->Ty, RetType, N }); + addConstraint(new CEqual { Decl->Ty, RetType, Decl }); } } @@ -611,18 +663,19 @@ namespace bolt { TVarRigid* Checker::createRigidVar(ByteString Name) { auto TV = new TVarRigid(NextTypeVarId++, Name); - Contexts.back()->TVs->emplace(TV); + getContext().TVs->emplace(TV); return TV; } TVar* Checker::createTypeVar() { auto TV = new TVar(NextTypeVarId++, VarKind::Unification); - Contexts.back()->TVs->emplace(TV); + getContext().TVs->emplace(TV); return TV; } - InferContext* Checker::createInferContext(TVSet* TVs, ConstraintSet* Constraints) { + InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) { auto Ctx = new InferContext; + Ctx->Parent = Parent; Ctx->TVs = new TVSet; Ctx->Constraints = new ConstraintSet; return Ctx; @@ -813,13 +866,12 @@ namespace bolt { } Ty = createTypeVar(); for (auto Case: Match->Cases) { - auto NewCtx = createInferContext(); - Contexts.push_back(NewCtx); + pushContext(Case->Ctx); auto PattTy = inferPattern(Case->Pattern); addConstraint(new CEqual(PattTy, ValTy, X)); auto ExprTy = inferExpression(Case->Expression); addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); - Contexts.pop_back(); + popContext(); } if (!Match->Value) { Ty = new TArrow({ ValTy }, Ty); @@ -1036,20 +1088,25 @@ namespace bolt { auto Y = static_cast(N); auto Def = Y->getScope()->lookup(Y->getSymbolPath()); // Name lookup failures will be reported directly in inferExpression(). - // Parameters are clearly no let-decarations. They never have their own - // inference context, so we have to skip them. - if (Def == nullptr || Def->getKind() == NodeKind::Parameter) { + if (Def == nullptr || Def->getKind() == NodeKind::SourceFile) { return; } - ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration || Def->getKind() == NodeKind::SourceFile); - RefGraph.addEdge(Stack.top(), Def); + // This case ensures that a deeply nested structure that references a + // parameter of a parent node but is not referenced itself is correctly handled. + // Note that the edge goes from the parent let to the parameter. This is normal. + if (Def->getKind() == NodeKind::Parameter) { + RefGraph.addEdge(Stack.top(), Def->Parent); + return; + } + ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration); + if (!Stack.empty()) { + RefGraph.addEdge(Def, Stack.top()); + } } }; - RefGraph.addVertex(SF); Visitor V { {}, RefGraph }; - V.Stack.push(SF); V.visit(SF); } @@ -1189,8 +1246,8 @@ namespace bolt { } void Checker::check(SourceFile *SF) { - auto RootContext = createInferContext(); - Contexts.push_back(RootContext); + initialize(SF); + pushContext(SF->Ctx); addBinding("String", new Forall(StringType)); addBinding("Int", new Forall(IntType)); addBinding("Bool", new Forall(BoolType)); @@ -1206,9 +1263,6 @@ namespace bolt { forwardDeclare(SF); auto SCCs = RefGraph.strongconnect(); for (auto Nodes: SCCs) { - if (Nodes.size() == 1 && llvm::isa(Nodes[0])) { - continue; - } auto TVs = new TVSet; auto Constraints = new ConstraintSet; for (auto N: Nodes) { @@ -1217,9 +1271,6 @@ namespace bolt { } } for (auto Nodes: SCCs) { - if (Nodes.size() == 1 && llvm::isa(Nodes[0])) { - continue; - } for (auto N: Nodes) { auto Decl = static_cast(N); Decl->IsCycleActive = true; @@ -1234,8 +1285,8 @@ namespace bolt { } } infer(SF); - Contexts.pop_back(); - solve(new CMany(*RootContext->Constraints)); + popContext(); + solve(new CMany(*SF->Ctx->Constraints)); checkTypeclassSigs(SF); } @@ -1349,6 +1400,8 @@ namespace bolt { void Checker::join(TVar* TV, Type* Ty) { + // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; + TV->set(Ty); propagateClasses(TV->Contexts, Ty); @@ -1364,9 +1417,9 @@ namespace bolt { // Should it get assigned another unification variable, that's OK too // because then that variable is what matters and it will become the new // (possibly polymorphic) variable. - if (!Contexts.empty()) { + if (ActiveContext) { // std::cerr << "erase " << describe(TV) << std::endl; - auto TVs = Contexts.back()->TVs; + auto TVs = ActiveContext->TVs; TVs->erase(TV); }