Fix some regressions due to previous commits

This commit is contained in:
Sam Vervaeck 2023-05-30 15:27:21 +02:00
parent a8f8658f27
commit 63547ee0a5
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
4 changed files with 46 additions and 13 deletions

View file

@ -160,9 +160,9 @@ namespace bolt {
case NodeKind::LetDeclaration:
return static_cast<D*>(this)->visitLetDeclaration(static_cast<LetDeclaration*>(N));
case NodeKind::RecordDeclarationField:
return static_cast<D*>(this)->visitStructDeclarationField(static_cast<RecordDeclarationField*>(N));
return static_cast<D*>(this)->visitRecordDeclarationField(static_cast<RecordDeclarationField*>(N));
case NodeKind::RecordDeclaration:
return static_cast<D*>(this)->visitStructDeclaration(static_cast<RecordDeclaration*>(N));
return static_cast<D*>(this)->visitRecordDeclaration(static_cast<RecordDeclaration*>(N));
case NodeKind::VariantDeclaration:
return static_cast<D*>(this)->visitVariantDeclaration(static_cast<VariantDeclaration*>(N));
case NodeKind::TupleVariantDeclarationMember:
@ -504,11 +504,11 @@ namespace bolt {
visitNode(N);
}
void visitStructDeclarationField(RecordDeclarationField* N) {
void visitRecordDeclarationField(RecordDeclarationField* N) {
visitNode(N);
}
void visitStructDeclaration(RecordDeclaration* N) {
void visitRecordDeclaration(RecordDeclaration* N) {
visitNode(N);
}

View file

@ -89,6 +89,7 @@ namespace bolt {
switch (X->getKind()) {
case NodeKind::LetExprBody:
case NodeKind::ExpressionStatement:
case NodeKind::IfStatement:
case NodeKind::ReturnStatement:
break;
case NodeKind::LetBlockBody:
@ -99,6 +100,9 @@ namespace bolt {
}
break;
}
case NodeKind::InstanceDeclaration:
// We ignore let-declarations inside instance-declarations for now
break;
case NodeKind::ClassDeclaration:
{
auto Decl = static_cast<ClassDeclaration*>(X);

View file

@ -166,6 +166,9 @@ namespace bolt {
{
auto Y = static_cast<CEqual*>(C);
// This will store all inference contexts in Contexts, from most local
// one to most general one. Because this order is not ideal, the code
// below will have to handle that.
auto Curr = &getContext();
std::vector<InferContext*> Contexts;
for (;;) {
@ -176,28 +179,42 @@ namespace bolt {
}
}
// If no MaxLevelLeft was found, that means that not a single
// corresponding type variable was found in the contexts. We set it to
// 0, which corresponds to the global inference context.
std::size_t MaxLevelLeft = 0;
for (std::size_t I = 0; I < Contexts.size(); I++) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Left)) {
MaxLevelLeft = I;
MaxLevelLeft = Contexts.size() - I - 1;
break;
}
}
// Same as above but now mirrored for Y->Right
std::size_t MaxLevelRight = 0;
for (std::size_t I = 0; I < Contexts.size(); I++) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Right)) {
MaxLevelRight = I;
MaxLevelRight = Contexts.size() - I - 1;
break;
}
}
// The lowest index is determined by the one that has no type variables
// in Y->Left AND in Y->Right. This implies max() must be used, so that
// the very first enounter of a type variable matters.
auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight);
// Now find the highest index I such that all the contexts that are more
// local do not contain any type variables that are present in the
// equality constraint.
std::size_t MinLevel = MaxLevel;
for (std::size_t I = Contexts.size(); I-- > 0; ) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
// No need to reverse because even though Contexts is reversed, we
// are also iterating in reverse.
MinLevel = I;
break;
}
@ -206,7 +223,7 @@ namespace bolt {
if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) {
solveCEqual(Y);
} else {
Contexts[MaxLevel]->Constraints->push_back(C);
Contexts[Contexts.size() - MaxLevel - 1]->Constraints->push_back(C);
}
break;
@ -341,7 +358,6 @@ namespace bolt {
std::vector<TVar*> Vars;
for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
Vars.push_back(TV);
}
@ -517,10 +533,10 @@ namespace bolt {
}
}
popContext();
Type* BindTy;
if (Let->isFunc()) {
popContext();
BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
} else {
BindTy = inferPattern(Let->Pattern);
@ -566,9 +582,8 @@ namespace bolt {
RetType = createTypeVar();
}
popContext();
if (Decl->isFunc()) {
popContext();
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
} else {
// Declaration is a plain (typed) variable
@ -698,12 +713,20 @@ namespace bolt {
for (auto Constraint: *F->Constraints) {
// FIXME improve this
if (Constraint->getKind() == ConstraintKind::Equal) {
auto Eq = static_cast<CEqual*>(Constraint);
Eq->Left = simplifyType(Eq->Left);
Eq->Right = simplifyType(Eq->Right);
}
auto NewConstraint = Constraint->substitute(Sub);
// This makes error messages prettier by relating the typing failure
// to the call site rather than the definition.
if (NewConstraint->getKind() == ConstraintKind::Equal) {
static_cast<CEqual*>(NewConstraint)->Source = Source;
auto Eq = static_cast<CEqual*>(Constraint);
Eq->Source = Source;
}
addConstraint(NewConstraint);
@ -868,7 +891,7 @@ namespace bolt {
for (auto Case: Match->Cases) {
pushContext(Case->Ctx);
auto PattTy = inferPattern(Case->Pattern);
addConstraint(new CEqual(PattTy, ValTy, X));
addConstraint(new CEqual(PattTy, ValTy, Case));
auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
popContext();

View file

@ -118,6 +118,12 @@ namespace bolt {
return "a function or variable reference";
case NodeKind::MatchExpression:
return "a match-expression";
case NodeKind::ConstantExpression:
return "a literal expression";
case NodeKind::IfStatement:
return "an if-statement";
case NodeKind::IfStatementPart:
return "a branch of an if-statement";
default:
ZEN_UNREACHABLE
}