diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index faf898e29..f310b7ed9 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -179,6 +179,8 @@ namespace bolt { void setContext(InferContext* Ctx); void popContext(); + void makeEqual(Type* A, Type* B, Node* Source); + void addConstraint(Constraint* Constraint); /** diff --git a/src/Checker.cc b/src/Checker.cc index 9cfc36c5c..0f2f4b19b 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -136,6 +136,10 @@ namespace bolt { return *ActiveContext; } + void Checker::makeEqual(Type* A, Type* B, Node* Source) { + makeEqual(A, B, Source); + } + void Checker::addConstraint(Constraint* C) { switch (C->getKind()) { case ConstraintKind::Equal: @@ -490,7 +494,7 @@ namespace bolt { // e.g. Bool, which causes the type assert to also collapse to e.g. // Bool -> Bool -> Bool. for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { - addConstraint(new CEqual(Param, TE->getType(), TE)); + makeEqual(Param, TE->getType(), TE); } // It would be very strange if there was no type assert in the type @@ -500,7 +504,7 @@ namespace bolt { // because we need to re-generate the type within the local context of // this let-declaration. // TODO make CEqual accept multiple nodes - addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let)); + makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let); } } @@ -650,14 +654,14 @@ namespace bolt { auto E = static_cast(Decl->Body); auto Ty2 = inferExpression(E->Expression); if (Ty) { - addConstraint(new CEqual(Ty, Ty2, Decl)); + makeEqual(Ty, Ty2, Decl); } else { Ty = Ty2; } } auto Ty3 = inferPattern(Decl->Pattern); if (Ty) { - addConstraint(new CEqual(Ty, Ty3, Decl)); + makeEqual(Ty, Ty3, Decl); } else { Ty = Ty3; } @@ -766,7 +770,7 @@ namespace bolt { case NodeKind::EqualityConstraintExpression: { auto D = static_cast(C); - addConstraint(new CEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C)); + makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); break; } default: @@ -902,9 +906,9 @@ namespace bolt { auto OldCtx = &getContext(); setContext(Case->Ctx); auto PattTy = inferPattern(Case->Pattern); - addConstraint(new CEqual(PattTy, ValTy, Case)); + makeEqual(PattTy, ValTy, Case); auto ExprTy = inferExpression(Case->Expression); - addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); + makeEqual(ExprTy, Ty, Case->Expression); setContext(OldCtx); } if (!Match->Value) { @@ -1008,7 +1012,7 @@ namespace bolt { auto K = static_cast(Member->Name); Ty = createTypeVar(); auto RestTy = createTypeVar(); - addConstraint(new CEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member)); + makeEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member); break; } default: @@ -1064,7 +1068,7 @@ namespace bolt { } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); - addConstraint(new CEqual(Ty, new TArrow(ParamTypes, RetTy), P)); + makeEqual(Ty, new TArrow(ParamTypes, RetTy), P); return RetTy; } @@ -1083,7 +1087,7 @@ namespace bolt { auto P = static_cast(Pattern); auto ElementType = createTypeVar(); for (auto [Element, Separator]: P->Elements) { - addConstraint(new CEqual(ElementType, inferPattern(Element), P)); + makeEqual(ElementType, inferPattern(Element), P); } return new TApp(ListType, ElementType); }