From 4294063921943d7fdc15e08cf5e0a309749d37ec Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sat, 3 Jun 2023 11:45:14 +0200 Subject: [PATCH] Simplify TArrow type as described in issue #42 --- include/bolt/Type.hpp | 23 ++++++++---- src/Checker.cc | 87 ++++++++++++++++++------------------------- src/Diagnostics.cc | 24 +++--------- src/Types.cc | 48 ++++++------------------ 4 files changed, 68 insertions(+), 114 deletions(-) diff --git a/include/bolt/Type.hpp b/include/bolt/Type.hpp index 67c1b756f..e15a873b6 100644 --- a/include/bolt/Type.hpp +++ b/include/bolt/Type.hpp @@ -81,8 +81,8 @@ namespace bolt { return { TypeIndexKind::FieldRestType }; } - static TypeIndex forArrowParamType(std::size_t I) { - return { TypeIndexKind::ArrowParamType, I }; + static TypeIndex forArrowParamType() { + return { TypeIndexKind::ArrowParamType }; } static TypeIndex forArrowReturnType() { @@ -303,16 +303,25 @@ namespace bolt { class TArrow : public Type { public: - std::vector ParamTypes; + Type* ParamType; Type* ReturnType; inline TArrow( - std::vector ParamTypes, + Type* ParamType, Type* ReturnType ): Type(TypeKind::Arrow), - ParamTypes(ParamTypes), + ParamType(ParamType), ReturnType(ReturnType) {} + static Type* build(std::vector ParamTypes, Type* ReturnType) { + Type* Curr = ReturnType; + for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) { + Curr = new TArrow(*Iter, Curr); + } + return Curr; + } + + static bool classof(const Type* Ty) { return Ty->getKind() == TypeKind::Arrow; } @@ -474,9 +483,7 @@ namespace bolt { case TypeKind::Arrow: { auto Arrow = static_cast*>(Ty); - for (auto I = 0; I < Arrow->ParamTypes.size(); ++I) { - visit(Arrow->ParamTypes[I]); - } + visit(Arrow->ParamType); visit(Arrow->ReturnType); break; } diff --git a/src/Checker.cc b/src/Checker.cc index fa9b3d2ee..1b0652c03 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -313,7 +313,7 @@ namespace bolt { for (auto Element: TupleMember->Elements) { ParamTypes.push_back(inferTypeExpression(Element)); } - Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy))); + Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, TArrow::build(ParamTypes, RetTy))); break; } case NodeKind::RecordVariantDeclarationMember: @@ -358,7 +358,7 @@ namespace bolt { for (auto TV: Vars) { RetTy = new TApp(RetTy, TV); } - Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); + Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(FieldsTy, RetTy))); popContext(); break; @@ -570,7 +570,7 @@ namespace bolt { RetType = createTypeVar(); } - makeEqual(Decl->Ty, new TArrow(ParamTypes, RetType), Decl); + makeEqual(Decl->Ty, TArrow::build(ParamTypes, RetType), Decl); } @@ -849,7 +849,7 @@ namespace bolt { ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); } auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); - auto Ty = new TArrow(ParamTypes, ReturnType); + auto Ty = TArrow::build(ParamTypes, ReturnType); N->setType(Ty); return Ty; } @@ -910,7 +910,7 @@ namespace bolt { setContext(OldCtx); } if (!Match->Value) { - Ty = new TArrow({ ValTy }, Ty); + Ty = new TArrow(ValTy, Ty); } break; } @@ -962,7 +962,7 @@ namespace bolt { for (auto Arg: Call->Args) { ArgTypes.push_back(inferExpression(Arg)); } - makeEqual(OpTy, new TArrow(ArgTypes, Ty), X); + makeEqual(OpTy, TArrow::build(ArgTypes, Ty), X); break; } @@ -979,7 +979,7 @@ namespace bolt { std::vector ArgTys; ArgTys.push_back(inferExpression(Infix->LHS)); ArgTys.push_back(inferExpression(Infix->RHS)); - makeEqual(new TArrow(ArgTys, Ty), OpTy, X); + makeEqual(TArrow::build(ArgTys, Ty), OpTy, X); break; } @@ -1066,7 +1066,7 @@ namespace bolt { } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); - makeEqual(Ty, new TArrow(ParamTypes, RetTy), P); + makeEqual(Ty, TArrow::build(ParamTypes, RetTy), P); return RetTy; } @@ -1181,11 +1181,11 @@ namespace bolt { addBinding("True", new Forall(BoolType)); addBinding("False", new Forall(BoolType)); auto A = createTypeVar(); - addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, new TArrow({ A, A }, BoolType))); - addBinding("+", new Forall(new TArrow({ IntType, IntType }, IntType))); - addBinding("-", new Forall(new TArrow({ IntType, IntType }, IntType))); - addBinding("*", new Forall(new TArrow({ IntType, IntType }, IntType))); - addBinding("/", new Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, TArrow::build({ A, A }, BoolType))); + addBinding("+", new Forall(TArrow::build({ IntType, IntType }, IntType))); + addBinding("-", new Forall(TArrow::build({ IntType, IntType }, IntType))); + addBinding("*", new Forall(TArrow::build({ IntType, IntType }, IntType))); + addBinding("/", new Forall(TArrow::build({ IntType, IntType }, IntType))); populate(SF); forwardDeclare(SF); auto SCCs = RefGraph.strongconnect(); @@ -1593,62 +1593,47 @@ namespace bolt { } if (llvm::isa(A) && llvm::isa(B)) { - auto C1 = ArrowCursor(static_cast(A), DidSwap ? RightPath : LeftPath); - auto C2 = ArrowCursor(static_cast(B), DidSwap ? LeftPath : RightPath); + auto Arrow1 = static_cast(A); + auto Arrow2 = static_cast(B); bool Success = true; - for (;;) { - auto T1 = C1.next(); - auto T2 = C2.next(); - if (T1 == nullptr && T2 == nullptr) { - break; - } - if (T1 == nullptr || T2 == nullptr) { - unifyError(); - Success = false; - break; - } - if (!unify(T1, T2, DidSwap)) { - Success = false; - } + LeftPath.push_back(TypeIndex::forArrowParamType()); + RightPath.push_back(TypeIndex::forArrowParamType()); + if (!unify(Arrow1->ParamType, Arrow2->ParamType, DidSwap)) { + Success = false; } + LeftPath.pop_back(); + RightPath.pop_back(); + LeftPath.push_back(TypeIndex::forArrowReturnType()); + RightPath.push_back(TypeIndex::forArrowReturnType()); + if (!unify(Arrow1->ReturnType, Arrow2->ReturnType, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); return Success; - /* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */ - /* return false; */ - /* } */ - /* auto Count = Arr1->ParamTypes.size(); */ - /* for (std::size_t I = 0; I < Count; I++) { */ - /* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */ - /* return false; */ - /* } */ - /* } */ - /* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */ } if (llvm::isa(A) && llvm::isa(B)) { auto App1 = static_cast(A); auto App2 = static_cast(B); bool Success = true; + LeftPath.push_back(TypeIndex::forAppOpType()); + RightPath.push_back(TypeIndex::forAppOpType()); if (!unify(App1->Op, App2->Op, DidSwap)) { Success = false; } + LeftPath.pop_back(); + RightPath.pop_back(); + LeftPath.push_back(TypeIndex::forAppArgType()); + RightPath.push_back(TypeIndex::forAppArgType()); if (!unify(App1->Arg, App2->Arg, DidSwap)) { Success = false; } + LeftPath.pop_back(); + RightPath.pop_back(); return Success; } - if (llvm::isa(B)) { - swap(); - } - - if (llvm::isa(A)) { - auto Arr = static_cast(A); - if (Arr->ParamTypes.empty()) { - auto Success = unify(Arr->ReturnType, B, DidSwap); - return Success; - } - } - if (llvm::isa(A) && llvm::isa(B)) { auto Tuple1 = static_cast(A); auto Tuple2 = static_cast(B); diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 3ce5a2d0c..30569f8a4 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -165,14 +165,7 @@ namespace bolt { { auto Y = static_cast(Ty); std::ostringstream Out; - Out << "("; - bool First = true; - for (auto PT: Y->ParamTypes) { - if (First) First = false; - else Out << ", "; - Out << describe(PT); - } - Out << ") -> " << describe(Y->ReturnType); + Out << describe(Y->ParamType) << " -> " << describe(Y->ReturnType); return Out.str(); } case TypeKind::Con: @@ -558,17 +551,10 @@ namespace bolt { } void visitArrowType(const TArrow* Ty) override { - W.write("("); - bool First = true; - std::size_t I = 0; - for (auto PT: Ty->ParamTypes) { - if (First) First = false; - else W.write(", "); - Path.push_back(TypeIndex::forArrowParamType(I++)); - visit(PT); - Path.pop_back(); - } - W.write(") -> "); + Path.push_back(TypeIndex::forArrowParamType()); + visit(Ty->ParamType); + Path.pop_back(); + W.write(" -> "); Path.push_back(TypeIndex::forArrowReturnType()); visit(Ty->ReturnType); Path.pop_back(); diff --git a/src/Types.cc b/src/Types.cc index 1b8cc43f3..a198ef624 100644 --- a/src/Types.cc +++ b/src/Types.cc @@ -44,15 +44,11 @@ namespace bolt { Kind = TypeIndexKind::AppArgType; break; case TypeIndexKind::ArrowParamType: - { - auto Arrow = llvm::cast(Ty); - if (I+1 < Arrow->ParamTypes.size()) { - ++I; - } else { - Kind = TypeIndexKind::ArrowReturnType; - } + Kind = TypeIndexKind::ArrowReturnType; + break; + case TypeIndexKind::ArrowReturnType: + Kind = TypeIndexKind::End; break; - } case TypeIndexKind::FieldType: Kind = TypeIndexKind::FieldRestType; break; @@ -60,9 +56,6 @@ namespace bolt { case TypeIndexKind::TupleIndexType: case TypeIndexKind::PresentType: case TypeIndexKind::AppArgType: - case TypeIndexKind::ArrowReturnType: - Kind = TypeIndexKind::End; - break; case TypeIndexKind::TupleElement: { auto Tuple = llvm::cast(Ty); @@ -88,19 +81,15 @@ namespace bolt { { auto Arrow = static_cast(Ty2); bool Changed = false; - std::vector NewParamTypes; - for (auto Ty: Arrow->ParamTypes) { - auto NewParamType = Ty->rewrite(Fn); - if (NewParamType != Ty) { - Changed = true; - } - NewParamTypes.push_back(NewParamType); + Type* NewParamType = Arrow->ParamType->rewrite(Fn); + if (NewParamType != Arrow->ParamType) { + Changed = true; } auto NewRetTy = Arrow->ReturnType->rewrite(Fn); if (NewRetTy != Arrow->ReturnType) { Changed = true; } - return Changed ? new TArrow(NewParamTypes, NewRetTy) : Ty2; + return Changed ? new TArrow(NewParamType, NewRetTy) : Ty2; } case TypeKind::Con: return Ty2; @@ -173,9 +162,7 @@ namespace bolt { case TypeKind::Arrow: { auto Arrow = static_cast(this); - for (auto Ty: Arrow->ParamTypes) { - Ty->addTypeVars(TVs); - } + Arrow->ParamType->addTypeVars(TVs); Arrow->ReturnType->addTypeVars(TVs); break; } @@ -229,12 +216,7 @@ namespace bolt { case TypeKind::Arrow: { auto Arrow = static_cast(this); - for (auto Ty: Arrow->ParamTypes) { - if (Ty->hasTypeVar(TV)) { - return true; - } - } - return Arrow->ReturnType->hasTypeVar(TV); + return Arrow->ParamType->hasTypeVar(TV) || Arrow->ReturnType->hasTypeVar(TV); } case TypeKind::Con: return false; @@ -308,7 +290,7 @@ namespace bolt { case TypeIndexKind::TupleElement: return llvm::cast(this)->ElementTypes[Index.I]; case TypeIndexKind::ArrowParamType: - return llvm::cast(this)->ParamTypes[Index.I]; + return llvm::cast(this)->ParamType; case TypeIndexKind::ArrowReturnType: return llvm::cast(this)->ReturnType; case TypeIndexKind::FieldType: @@ -445,13 +427,7 @@ namespace bolt { TypeIndex Type::getStartIndex() { switch (Kind) { case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - if (Arrow->ParamTypes.empty()) { - return TypeIndex::forArrowReturnType(); - } - return TypeIndex::forArrowParamType(0); - } + return TypeIndex::forArrowParamType(); case TypeKind::Tuple: { auto Tuple = static_cast(this);