Add Checker::makeEqual between CEqual and Checker::addConstraint

This commit is contained in:
Sam Vervaeck 2023-06-02 20:56:12 +02:00
parent ca2eb24da2
commit bf77031dd5
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
2 changed files with 16 additions and 10 deletions

View file

@ -179,6 +179,8 @@ namespace bolt {
void setContext(InferContext* Ctx); void setContext(InferContext* Ctx);
void popContext(); void popContext();
void makeEqual(Type* A, Type* B, Node* Source);
void addConstraint(Constraint* Constraint); void addConstraint(Constraint* Constraint);
/** /**

View file

@ -136,6 +136,10 @@ namespace bolt {
return *ActiveContext; return *ActiveContext;
} }
void Checker::makeEqual(Type* A, Type* B, Node* Source) {
makeEqual(A, B, Source);
}
void Checker::addConstraint(Constraint* C) { void Checker::addConstraint(Constraint* C) {
switch (C->getKind()) { switch (C->getKind()) {
case ConstraintKind::Equal: case ConstraintKind::Equal:
@ -490,7 +494,7 @@ namespace bolt {
// e.g. Bool, which causes the type assert to also collapse to e.g. // e.g. Bool, which causes the type assert to also collapse to e.g.
// Bool -> Bool -> Bool. // Bool -> Bool -> Bool.
for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) {
addConstraint(new CEqual(Param, TE->getType(), TE)); makeEqual(Param, TE->getType(), TE);
} }
// It would be very strange if there was no type assert in the type // It would be very strange if there was no type assert in the type
@ -500,7 +504,7 @@ namespace bolt {
// because we need to re-generate the type within the local context of // because we need to re-generate the type within the local context of
// this let-declaration. // this let-declaration.
// TODO make CEqual accept multiple nodes // TODO make CEqual accept multiple nodes
addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let)); makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let);
} }
} }
@ -650,14 +654,14 @@ namespace bolt {
auto E = static_cast<LetExprBody*>(Decl->Body); auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression); auto Ty2 = inferExpression(E->Expression);
if (Ty) { if (Ty) {
addConstraint(new CEqual(Ty, Ty2, Decl)); makeEqual(Ty, Ty2, Decl);
} else { } else {
Ty = Ty2; Ty = Ty2;
} }
} }
auto Ty3 = inferPattern(Decl->Pattern); auto Ty3 = inferPattern(Decl->Pattern);
if (Ty) { if (Ty) {
addConstraint(new CEqual(Ty, Ty3, Decl)); makeEqual(Ty, Ty3, Decl);
} else { } else {
Ty = Ty3; Ty = Ty3;
} }
@ -766,7 +770,7 @@ namespace bolt {
case NodeKind::EqualityConstraintExpression: case NodeKind::EqualityConstraintExpression:
{ {
auto D = static_cast<EqualityConstraintExpression*>(C); auto D = static_cast<EqualityConstraintExpression*>(C);
addConstraint(new CEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C)); makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C);
break; break;
} }
default: default:
@ -902,9 +906,9 @@ namespace bolt {
auto OldCtx = &getContext(); auto OldCtx = &getContext();
setContext(Case->Ctx); setContext(Case->Ctx);
auto PattTy = inferPattern(Case->Pattern); auto PattTy = inferPattern(Case->Pattern);
addConstraint(new CEqual(PattTy, ValTy, Case)); makeEqual(PattTy, ValTy, Case);
auto ExprTy = inferExpression(Case->Expression); auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); makeEqual(ExprTy, Ty, Case->Expression);
setContext(OldCtx); setContext(OldCtx);
} }
if (!Match->Value) { if (!Match->Value) {
@ -1008,7 +1012,7 @@ namespace bolt {
auto K = static_cast<Identifier*>(Member->Name); auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar(); Ty = createTypeVar();
auto RestTy = createTypeVar(); auto RestTy = createTypeVar();
addConstraint(new CEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member)); makeEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member);
break; break;
} }
default: default:
@ -1064,7 +1068,7 @@ namespace bolt {
} }
auto Ty = instantiate(Scm, P); auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar(); auto RetTy = createTypeVar();
addConstraint(new CEqual(Ty, new TArrow(ParamTypes, RetTy), P)); makeEqual(Ty, new TArrow(ParamTypes, RetTy), P);
return RetTy; return RetTy;
} }
@ -1083,7 +1087,7 @@ namespace bolt {
auto P = static_cast<ListPattern*>(Pattern); auto P = static_cast<ListPattern*>(Pattern);
auto ElementType = createTypeVar(); auto ElementType = createTypeVar();
for (auto [Element, Separator]: P->Elements) { for (auto [Element, Separator]: P->Elements) {
addConstraint(new CEqual(ElementType, inferPattern(Element), P)); makeEqual(ElementType, inferPattern(Element), P);
} }
return new TApp(ListType, ElementType); return new TApp(ListType, ElementType);
} }