Fix variable declarations, fix regression due to eager solving, fix unification

This commit is contained in:
Sam Vervaeck 2023-05-23 16:07:58 +02:00
parent 302823ac9b
commit b23dc84f72
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
3 changed files with 105 additions and 64 deletions

View file

@ -11,6 +11,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <deque>
namespace bolt {
@ -207,6 +208,16 @@ namespace bolt {
TVSub Solution;
/**
* The queue that is used during solving to store any unsolved constraints.
*/
std::deque<class Constraint*> Queue;
/**
* Pointer to the current constraint being unified.
*/
CEqual* C;
std::vector<InferContext*> Contexts;
InferContext& getContext();
@ -265,6 +276,8 @@ namespace bolt {
Type* simplify(Type* Ty);
Type* find(Type* Ty);
/**
* Assign a type to a unification variable.
*

View file

@ -7,6 +7,8 @@
// TODO Fix TVSub to use TVar.Id instead of the pointer address
// Optimise constraint solving by solving some constraints during inference
#include <algorithm>
#include <iterator>
#include <stack>
@ -124,32 +126,33 @@ namespace bolt {
case ConstraintKind::Equal:
{
auto Y = static_cast<CEqual*>(C);
std::size_t MaxLevel = 0;
for (std::size_t I = Contexts.size(); I-- > 0; ) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
MaxLevel = I;
break;
}
}
std::size_t MinLevel = MaxLevel;
for (std::size_t I = 0; I < Contexts.size(); I++) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
MinLevel = I;
break;
}
}
if (MaxLevel == MinLevel) {
solveCEqual(Y);
} else {
Contexts[MaxLevel]->Constraints->push_back(C);
}
// Contexts.front()->Constraints->push_back(C);
//auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth);
//ZEN_ASSERT(I < Contexts.size());
// std::size_t MaxLevel = 0;
// for (std::size_t I = Contexts.size(); I-- > 0; ) {
// auto Ctx = Contexts[I];
//Ctx->Constraints.push_back(Constraint);
// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
// MaxLevel = I;
// break;
// }
// }
// std::size_t MinLevel = MaxLevel;
// for (std::size_t I = 0; I < Contexts.size(); I++) {
// auto Ctx = Contexts[I];
// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
// MinLevel = I;
// break;
// }
// }
// // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0
// if (MaxLevel == MinLevel) {
// solveCEqual(Y);
// } else {
// Contexts[MaxLevel]->Constraints->push_back(C);
// }
Contexts.back()->Constraints->push_back(C);
break;
}
case ConstraintKind::Many:
@ -232,10 +235,10 @@ namespace bolt {
{
auto Let = static_cast<LetDeclaration*>(X);
auto NewCtx = createInferContext();
Let->Ctx = NewCtx;
if (!Let->Params.empty()) {
Contexts.push_back(NewCtx);
Let->Ctx = createInferContext();
Contexts.push_back(Let->Ctx);
// If declaring a let-declaration inside a type class declaration,
// we need to mark that the let-declaration requires this class.
@ -245,11 +248,13 @@ namespace bolt {
auto Decl = static_cast<ClassDeclaration*>(Let->Parent);
for (auto TE: Decl->TypeVars) {
auto TV = llvm::cast<TVar>(TE->getType());
NewCtx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
NewCtx->TVs->emplace(TV);
Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
Let->Ctx->TVs->emplace(TV);
}
}
}
// Here we infer the primary type of the let declaration. If there's a
// type assert, that assert should be authoritative so we use that.
// Otherwise, the type is not further specified and we create a new
@ -276,7 +281,7 @@ namespace bolt {
std::vector<TVar *> Params;
for (auto TE: Class->TypeVars) {
auto TV = createTypeVar();
NewCtx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
Params.push_back(TV);
}
@ -304,7 +309,9 @@ namespace bolt {
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Let->Body);
NewCtx->ReturnType = createTypeVar();
if (!Let->Params.empty()) {
Let->Ctx->ReturnType = createTypeVar();
}
for (auto Element: Block->Elements) {
forwardDeclare(Element);
}
@ -315,9 +322,15 @@ namespace bolt {
}
}
if (!Let->Params.empty()) {
Contexts.pop_back();
}
inferBindings(Let->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs);
if (Let->Params.empty()) {
inferBindings(Let->Pattern, Ty);
} else {
inferBindings(Let->Pattern, Ty, Let->Ctx->Constraints, Let->Ctx->TVs);
}
break;
}
@ -380,8 +393,9 @@ namespace bolt {
{
auto Decl = static_cast<LetDeclaration*>(N);
auto NewCtx = Decl->Ctx;
Contexts.push_back(NewCtx);
if (!Decl->Params.empty()) {
Contexts.push_back(Decl->Ctx);
}
std::vector<Type*> ParamTypes;
Type* RetType;
@ -425,7 +439,9 @@ namespace bolt {
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N });
}
if (!Decl->Params.empty()) {
Contexts.pop_back();
}
break;
}
@ -783,6 +799,11 @@ namespace bolt {
void visitLetDeclaration(LetDeclaration* Decl) {
// Only inspect those let-declarations that look like a function
if (Decl->Params.empty()) {
return;
}
// Will contain the type classes that were specified in the type assertion by the user.
// There might be some other signatures as well, but those are an implementation detail.
std::vector<TypeclassSignature> Expected;
@ -923,14 +944,12 @@ namespace bolt {
void Checker::solve(Constraint* Constraint, TVSub& Solution) {
std::stack<class Constraint*> Queue;
Queue.push(Constraint);
Queue.push_back(Constraint);
while (!Queue.empty()) {
auto Constraint = Queue.top();
Queue.pop();
auto Constraint = Queue.front();
Queue.pop_front();
switch (Constraint->getKind()) {
@ -947,7 +966,7 @@ namespace bolt {
{
auto Many = static_cast<CMany*>(Constraint);
for (auto Constraint: Many->Elements) {
Queue.push(Constraint);
Queue.push_back(Constraint);
}
break;
}
@ -1028,18 +1047,16 @@ namespace bolt {
};
void Checker::solveCEqual(CEqual* C) {
/* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */
// std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl;
OrigLeft = C->Left;
OrigRight = C->Right;
Source = C->Source;
unify(C->Left, C->Right);
LeftPath = {};
RightPath = {};
/* DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source); */
}
Type* Checker::simplify(Type* Ty) {
Type* Checker::find(Type* Ty) {
while (Ty->getKind() == TypeKind::Var) {
auto Match = Solution.find(static_cast<TVar*>(Ty));
if (Match == Solution.end()) {
@ -1047,6 +1064,12 @@ namespace bolt {
}
Ty = Match->second;
}
return Ty;
}
Type* Checker::simplify(Type* Ty) {
Ty = find(Ty);
switch (Ty->getKind()) {
@ -1136,11 +1159,11 @@ namespace bolt {
// type such as Int. We therefore never want the variable to be polymorphic
// and be instantiated with a fresh variable, as it has already been solved.
// Should it get assigned another unification variable, that's OK too
// because then the context of that variable is what matters and not anymore
// the context of this one.
if (!Contexts.empty()) {
Contexts.back()->TVs->erase(TV);
}
// because then the inference context of that variable is what matters and
// not anymore the context of this one.
// if (!Contexts.empty()) {
// Contexts.back()->TVs->erase(TV);
// }
}
@ -1197,8 +1220,6 @@ namespace bolt {
};
bool Checker::unify(Type* A, Type* B) {
A = simplify(A);
@ -1322,6 +1343,11 @@ namespace bolt {
return Success;
}
if (llvm::isa<TTupleIndex>(A) || llvm::isa<TTupleIndex>(B)) {
Queue.push_back(C);
return true;
}
// if (llvm::isa<TTupleIndex>(A) && llvm::isa<TTupleIndex>(B)) {
// auto Index1 = static_cast<TTupleIndex*>(A);
// auto Index2 = static_cast<TTupleIndex*>(B);

View file

@ -1,4 +1,6 @@
// FIXME writeExcerpt does not work well with the last line in a file
#include <sstream>
#include <cmath>