diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 43b2682d2..d551ff24d 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace bolt { @@ -207,6 +208,16 @@ namespace bolt { TVSub Solution; + /** + * The queue that is used during solving to store any unsolved constraints. + */ + std::deque Queue; + + /** + * Pointer to the current constraint being unified. + */ + CEqual* C; + std::vector Contexts; InferContext& getContext(); @@ -265,6 +276,8 @@ namespace bolt { Type* simplify(Type* Ty); + Type* find(Type* Ty); + /** * Assign a type to a unification variable. * diff --git a/src/Checker.cc b/src/Checker.cc index 7b24fb051..231573889 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -7,6 +7,8 @@ // TODO Fix TVSub to use TVar.Id instead of the pointer address +// Optimise constraint solving by solving some constraints during inference + #include #include #include @@ -124,32 +126,33 @@ namespace bolt { case ConstraintKind::Equal: { auto Y = static_cast(C); - std::size_t MaxLevel = 0; - for (std::size_t I = Contexts.size(); I-- > 0; ) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { - MaxLevel = I; - break; - } - } - std::size_t MinLevel = MaxLevel; - for (std::size_t I = 0; I < Contexts.size(); I++) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { - MinLevel = I; - break; - } - } - if (MaxLevel == MinLevel) { - solveCEqual(Y); - } else { - Contexts[MaxLevel]->Constraints->push_back(C); - } - // Contexts.front()->Constraints->push_back(C); - //auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth); - //ZEN_ASSERT(I < Contexts.size()); - //auto Ctx = Contexts[I]; - //Ctx->Constraints.push_back(Constraint); + +// std::size_t MaxLevel = 0; +// for (std::size_t I = Contexts.size(); I-- > 0; ) { +// auto Ctx = Contexts[I]; +// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { +// MaxLevel = I; +// break; +// } +// } + +// std::size_t MinLevel = MaxLevel; +// for (std::size_t I = 0; I < Contexts.size(); I++) { +// auto Ctx = Contexts[I]; +// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { +// MinLevel = I; +// break; +// } +// } + +// // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0 +// if (MaxLevel == MinLevel) { +// solveCEqual(Y); +// } else { +// Contexts[MaxLevel]->Constraints->push_back(C); +// } + + Contexts.back()->Constraints->push_back(C); break; } case ConstraintKind::Many: @@ -232,29 +235,31 @@ namespace bolt { { auto Let = static_cast(X); - auto NewCtx = createInferContext(); - Let->Ctx = NewCtx; + if (!Let->Params.empty()) { - Contexts.push_back(NewCtx); + Let->Ctx = createInferContext(); + Contexts.push_back(Let->Ctx); - // If declaring a let-declaration inside a type class declaration, - // we need to mark that the let-declaration requires this class. - // This marking is set on the rigid type variables of the class, which - // are then added to this local type environment. - if (llvm::isa(Let->Parent)) { - auto Decl = static_cast(Let->Parent); - for (auto TE: Decl->TypeVars) { - auto TV = llvm::cast(TE->getType()); - NewCtx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); - NewCtx->TVs->emplace(TV); + // If declaring a let-declaration inside a type class declaration, + // we need to mark that the let-declaration requires this class. + // This marking is set on the rigid type variables of the class, which + // are then added to this local type environment. + if (llvm::isa(Let->Parent)) { + auto Decl = static_cast(Let->Parent); + for (auto TE: Decl->TypeVars) { + auto TV = llvm::cast(TE->getType()); + Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); + Let->Ctx->TVs->emplace(TV); + } } + } // Here we infer the primary type of the let declaration. If there's a // type assert, that assert should be authoritative so we use that. // Otherwise, the type is not further specified and we create a new // unification variable. - Type *Ty; + Type* Ty; if (Let->TypeAssert) { Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); } else { @@ -276,7 +281,7 @@ namespace bolt { std::vector Params; for (auto TE: Class->TypeVars) { auto TV = createTypeVar(); - NewCtx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); + Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); Params.push_back(TV); } @@ -304,7 +309,9 @@ namespace bolt { case NodeKind::LetBlockBody: { auto Block = static_cast(Let->Body); - NewCtx->ReturnType = createTypeVar(); + if (!Let->Params.empty()) { + Let->Ctx->ReturnType = createTypeVar(); + } for (auto Element: Block->Elements) { forwardDeclare(Element); } @@ -315,9 +322,15 @@ namespace bolt { } } - Contexts.pop_back(); + if (!Let->Params.empty()) { + Contexts.pop_back(); + } - inferBindings(Let->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs); + if (Let->Params.empty()) { + inferBindings(Let->Pattern, Ty); + } else { + inferBindings(Let->Pattern, Ty, Let->Ctx->Constraints, Let->Ctx->TVs); + } break; } @@ -380,8 +393,9 @@ namespace bolt { { auto Decl = static_cast(N); - auto NewCtx = Decl->Ctx; - Contexts.push_back(NewCtx); + if (!Decl->Params.empty()) { + Contexts.push_back(Decl->Ctx); + } std::vector ParamTypes; Type* RetType; @@ -425,7 +439,9 @@ namespace bolt { addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N }); } - Contexts.pop_back(); + if (!Decl->Params.empty()) { + Contexts.pop_back(); + } break; } @@ -783,6 +799,11 @@ namespace bolt { void visitLetDeclaration(LetDeclaration* Decl) { + // Only inspect those let-declarations that look like a function + if (Decl->Params.empty()) { + return; + } + // Will contain the type classes that were specified in the type assertion by the user. // There might be some other signatures as well, but those are an implementation detail. std::vector Expected; @@ -923,14 +944,12 @@ namespace bolt { void Checker::solve(Constraint* Constraint, TVSub& Solution) { - std::stack Queue; - Queue.push(Constraint); + Queue.push_back(Constraint); while (!Queue.empty()) { - auto Constraint = Queue.top(); - - Queue.pop(); + auto Constraint = Queue.front(); + Queue.pop_front(); switch (Constraint->getKind()) { @@ -947,7 +966,7 @@ namespace bolt { { auto Many = static_cast(Constraint); for (auto Constraint: Many->Elements) { - Queue.push(Constraint); + Queue.push_back(Constraint); } break; } @@ -1028,18 +1047,16 @@ namespace bolt { }; void Checker::solveCEqual(CEqual* C) { - /* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */ + // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; OrigLeft = C->Left; OrigRight = C->Right; Source = C->Source; unify(C->Left, C->Right); LeftPath = {}; RightPath = {}; - /* DE.add(simplify(C->Left), simplify(C->Right), C->Source); */ } - Type* Checker::simplify(Type* Ty) { - + Type* Checker::find(Type* Ty) { while (Ty->getKind() == TypeKind::Var) { auto Match = Solution.find(static_cast(Ty)); if (Match == Solution.end()) { @@ -1047,6 +1064,12 @@ namespace bolt { } Ty = Match->second; } + return Ty; + } + + Type* Checker::simplify(Type* Ty) { + + Ty = find(Ty); switch (Ty->getKind()) { @@ -1136,11 +1159,11 @@ namespace bolt { // type such as Int. We therefore never want the variable to be polymorphic // and be instantiated with a fresh variable, as it has already been solved. // Should it get assigned another unification variable, that's OK too - // because then the context of that variable is what matters and not anymore - // the context of this one. - if (!Contexts.empty()) { - Contexts.back()->TVs->erase(TV); - } + // because then the inference context of that variable is what matters and + // not anymore the context of this one. + // if (!Contexts.empty()) { + // Contexts.back()->TVs->erase(TV); + // } } @@ -1197,8 +1220,6 @@ namespace bolt { }; - - bool Checker::unify(Type* A, Type* B) { A = simplify(A); @@ -1322,6 +1343,11 @@ namespace bolt { return Success; } + if (llvm::isa(A) || llvm::isa(B)) { + Queue.push_back(C); + return true; + } + // if (llvm::isa(A) && llvm::isa(B)) { // auto Index1 = static_cast(A); // auto Index2 = static_cast(B); diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index d26a4c485..f4f49ee4a 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -1,4 +1,6 @@ +// FIXME writeExcerpt does not work well with the last line in a file + #include #include