Add Checker::makeEqual between CEqual and Checker::addConstraint

This commit is contained in:
Sam Vervaeck 2023-06-02 20:56:12 +02:00
parent ca2eb24da2
commit bf77031dd5
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
2 changed files with 16 additions and 10 deletions

View file

@ -179,6 +179,8 @@ namespace bolt {
void setContext(InferContext* Ctx);
void popContext();
void makeEqual(Type* A, Type* B, Node* Source);
void addConstraint(Constraint* Constraint);
/**

View file

@ -136,6 +136,10 @@ namespace bolt {
return *ActiveContext;
}
void Checker::makeEqual(Type* A, Type* B, Node* Source) {
makeEqual(A, B, Source);
}
void Checker::addConstraint(Constraint* C) {
switch (C->getKind()) {
case ConstraintKind::Equal:
@ -490,7 +494,7 @@ namespace bolt {
// e.g. Bool, which causes the type assert to also collapse to e.g.
// Bool -> Bool -> Bool.
for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) {
addConstraint(new CEqual(Param, TE->getType(), TE));
makeEqual(Param, TE->getType(), TE);
}
// It would be very strange if there was no type assert in the type
@ -500,7 +504,7 @@ namespace bolt {
// because we need to re-generate the type within the local context of
// this let-declaration.
// TODO make CEqual accept multiple nodes
addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let));
makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let);
}
}
@ -650,14 +654,14 @@ namespace bolt {
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
if (Ty) {
addConstraint(new CEqual(Ty, Ty2, Decl));
makeEqual(Ty, Ty2, Decl);
} else {
Ty = Ty2;
}
}
auto Ty3 = inferPattern(Decl->Pattern);
if (Ty) {
addConstraint(new CEqual(Ty, Ty3, Decl));
makeEqual(Ty, Ty3, Decl);
} else {
Ty = Ty3;
}
@ -766,7 +770,7 @@ namespace bolt {
case NodeKind::EqualityConstraintExpression:
{
auto D = static_cast<EqualityConstraintExpression*>(C);
addConstraint(new CEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C));
makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C);
break;
}
default:
@ -902,9 +906,9 @@ namespace bolt {
auto OldCtx = &getContext();
setContext(Case->Ctx);
auto PattTy = inferPattern(Case->Pattern);
addConstraint(new CEqual(PattTy, ValTy, Case));
makeEqual(PattTy, ValTy, Case);
auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
makeEqual(ExprTy, Ty, Case->Expression);
setContext(OldCtx);
}
if (!Match->Value) {
@ -1008,7 +1012,7 @@ namespace bolt {
auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar();
auto RestTy = createTypeVar();
addConstraint(new CEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member));
makeEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member);
break;
}
default:
@ -1064,7 +1068,7 @@ namespace bolt {
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
addConstraint(new CEqual(Ty, new TArrow(ParamTypes, RetTy), P));
makeEqual(Ty, new TArrow(ParamTypes, RetTy), P);
return RetTy;
}
@ -1083,7 +1087,7 @@ namespace bolt {
auto P = static_cast<ListPattern*>(Pattern);
auto ElementType = createTypeVar();
for (auto [Element, Separator]: P->Elements) {
addConstraint(new CEqual(ElementType, inferPattern(Element), P));
makeEqual(ElementType, inferPattern(Element), P);
}
return new TApp(ListType, ElementType);
}