From 302823ac9bc572b3759076c6a392a5f1e68e06ce Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Mon, 22 May 2023 22:37:58 +0200 Subject: [PATCH] Split up Checker.hpp and make room for better type mismatch errors --- CMakeLists.txt | 1 + include/bolt/Checker.hpp | 198 ++++--------------- include/bolt/Diagnostics.hpp | 27 +-- include/bolt/Type.hpp | 282 ++++++++++++++++++++++++++ src/Checker.cc | 351 +++++++++++---------------------- src/Diagnostics.cc | 21 +- src/Types.cc | 369 +++++++++++++++++++++++++++++++++++ 7 files changed, 821 insertions(+), 428 deletions(-) create mode 100644 include/bolt/Type.hpp create mode 100644 src/Types.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 503ab5c25..d2ba796ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ add_library( src/Diagnostics.cc src/Scanner.cc src/Parser.cc + src/Types.cc src/Checker.cc src/IPRGraph.cc ) diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index a8d26582b..43b2682d2 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -6,170 +6,15 @@ #include "bolt/ByteString.hpp" #include "bolt/Common.hpp" #include "bolt/CST.hpp" -#include "bolt/Diagnostics.hpp" +#include "bolt/Type.hpp" -#include #include #include #include -#include namespace bolt { class DiagnosticEngine; - class Node; - - class Type; - class TVar; - - using TVSub = std::unordered_map; - using TVSet = std::unordered_set; - - using TypeclassContext = std::unordered_set; - - enum class TypeKind : unsigned char { - Var, - Con, - Arrow, - Tuple, - TupleIndex, - }; - - class Type { - - const TypeKind Kind; - - protected: - - inline Type(TypeKind Kind): - Kind(Kind) {} - - public: - - bool hasTypeVar(const TVar* TV); - - void addTypeVars(TVSet& TVs); - - inline TVSet getTypeVars() { - TVSet Out; - addTypeVars(Out); - return Out; - } - - Type* substitute(const TVSub& Sub); - - inline TypeKind getKind() const noexcept { - return Kind; - } - - }; - - class TCon : public Type { - public: - - const 
size_t Id; - std::vector Args; - ByteString DisplayName; - - inline TCon(const size_t Id, std::vector Args, ByteString DisplayName): - Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Con; - } - - }; - - enum class VarKind { - Rigid, - Unification, - }; - - class TVar : public Type { - public: - - const size_t Id; - VarKind VK; - - TypeclassContext Contexts; - - inline TVar(size_t Id, VarKind VK): - Type(TypeKind::Var), Id(Id), VK(VK) {} - - inline VarKind getVarKind() const noexcept { - return VK; - } - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Var; - } - - }; - - class TVarRigid : public TVar { - public: - - ByteString Name; - - inline TVarRigid(size_t Id, ByteString Name): - TVar(Id, VarKind::Rigid), Name(Name) {} - - }; - - class TArrow : public Type { - public: - - std::vector ParamTypes; - Type* ReturnType; - - inline TArrow( - std::vector ParamTypes, - Type* ReturnType - ): Type(TypeKind::Arrow), - ParamTypes(ParamTypes), - ReturnType(ReturnType) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Arrow; - } - - }; - - class TTuple : public Type { - public: - - std::vector ElementTypes; - - inline TTuple(std::vector ElementTypes): - Type(TypeKind::Tuple), ElementTypes(ElementTypes) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Tuple; - } - - }; - - class TTupleIndex : public Type { - public: - - Type* Ty; - std::size_t I; - - inline TTupleIndex(Type* Ty, std::size_t I): - Type(TypeKind::TupleIndex), Ty(Ty), I(I) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::TupleIndex; - } - - }; - - // template - // struct DerefHash { - // std::size_t operator()(const T& Value) const noexcept { - // return std::hash{}(*Value); - // } - // }; class Constraint; @@ -354,6 +199,8 @@ namespace bolt { std::unordered_map CallGraph; + std::unordered_map> 
InstanceMap; + Type* BoolType; Type* IntType; Type* StringType; @@ -412,20 +259,43 @@ namespace bolt { Type* instantiate(Scheme* S, Node* Source); - std::unordered_map> InstanceMap; - std::vector findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source); - void propagateClasses(TypeclassContext& Classes, Type* Ty, Node* Source); - void propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source); - - void checkTypeclassSigs(Node* N); + std::vector findInstanceContext(TCon* Ty, TypeclassId& Class); + void propagateClasses(TypeclassContext& Classes, Type* Ty); + void propagateClassTycon(TypeclassId& Class, TCon* Ty); Type* simplify(Type* Ty); - void join(TVar* A, Type* B, Node* Source); - bool unify(Type* A, Type* B, Node* Source); + /** + * Assign a type to a unification variable. + * + * If there are class constraints, those are propagated. + * + * If this type variable is solved during inference, it will be removed from + * the inference context. + * + * Other side effects may occur. + */ + void join(TVar* A, Type* B); + + Type* OrigLeft; + Type* OrigRight; + TypePath LeftPath; + TypePath RightPath; + Node* Source; + + bool unify(Type* A, Type* B); + + void unifyError(); void solveCEqual(CEqual* C); + void solve(Constraint* Constraint, TVSub& Solution); + /** + * Verifies that type class signatures on type asserts in let-declarations + * correctly declare the right type classes. 
+ */ + void checkTypeclassSigs(Node* N); + public: Checker(const LanguageConfig& Config, DiagnosticEngine& DE); diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index dbd7f8f16..32453b9ca 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -9,27 +9,10 @@ #include "bolt/ByteString.hpp" #include "bolt/String.hpp" #include "bolt/CST.hpp" +#include "bolt/Type.hpp" namespace bolt { - class Type; - class TCon; - class TVar; - class TTuple; - - using TypeclassId = ByteString; - - struct TypeclassSignature { - - using TypeclassId = ByteString; - TypeclassId Id; - std::vector Params; - - bool operator<(const TypeclassSignature& Other) const; - bool operator==(const TypeclassSignature& Other) const; - - }; - enum class DiagnosticKind : unsigned char { UnexpectedToken, UnexpectedString, @@ -95,13 +78,15 @@ namespace bolt { class UnificationErrorDiagnostic : public Diagnostic { public: - + Type* Left; Type* Right; + TypePath LeftPath; + TypePath RightPath; Node* Source; - inline UnificationErrorDiagnostic(Type* Left, Type* Right, Node* Source): - Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), Source(Source) {} + inline UnificationErrorDiagnostic(Type* Left, Type* Right, TypePath LeftPath, TypePath RightPath, Node* Source): + Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} }; diff --git a/include/bolt/Type.hpp b/include/bolt/Type.hpp new file mode 100644 index 000000000..704449564 --- /dev/null +++ b/include/bolt/Type.hpp @@ -0,0 +1,282 @@ + +#pragma once + +#include +#include +#include + +#include "bolt/ByteString.hpp" + +namespace bolt { + + class Type; + class TVar; + + using TypeclassId = ByteString; + + using TypeclassContext = std::unordered_set; + + struct TypeclassSignature { + + using TypeclassId = ByteString; + TypeclassId Id; + std::vector Params; + + bool operator<(const TypeclassSignature& Other) const; + 
bool operator==(const TypeclassSignature& Other) const; + + }; + + enum class TypeIndexKind { + ArrowParamType, + ArrowReturnType, + ConArg, + TupleElement, + End, + }; + + class TypeIndex { + protected: + + friend class Type; + friend class TypeIterator; + + TypeIndexKind Kind; + + union { + std::size_t I; + }; + + TypeIndex(TypeIndexKind Kind): + Kind(Kind) {} + + TypeIndex(TypeIndexKind Kind, std::size_t I): + Kind(Kind), I(I) {} + + public: + + bool operator==(const TypeIndex& Other) const noexcept; + + void advance(const Type* Ty); + + static TypeIndex forArrowReturnType() { + return { TypeIndexKind::ArrowReturnType }; + } + + static TypeIndex forArrowParamType(std::size_t I) { + return { TypeIndexKind::ArrowParamType, I }; + } + + static TypeIndex forConArg(std::size_t I) { + return { TypeIndexKind::ConArg, I }; + } + + static TypeIndex forTupleElement(std::size_t I) { + return { TypeIndexKind::TupleElement, I }; + } + + }; + + class TypeIterator { + + friend class Type; + + Type* Ty; + TypeIndex Index; + + TypeIterator(Type* Ty, TypeIndex Index): + Ty(Ty), Index(Index) {} + + public: + + TypeIterator& operator++() noexcept { + Index.advance(Ty); + return *this; + } + + bool operator==(const TypeIterator& Other) const noexcept { + return Ty == Other.Ty && Index == Other.Index; + } + + Type* operator*() { + return Ty; + } + + TypeIndex getIndex() const noexcept { + return Index; + } + + }; + + using TypePath = std::vector; + + using TVSub = std::unordered_map; + using TVSet = std::unordered_set; + + enum class TypeKind : unsigned char { + Var, + Con, + Arrow, + Tuple, + TupleIndex, + }; + + class Type { + + const TypeKind Kind; + + protected: + + inline Type(TypeKind Kind): + Kind(Kind) {} + + public: + + inline TypeKind getKind() const noexcept { + return Kind; + } + + bool hasTypeVar(const TVar* TV); + + void addTypeVars(TVSet& TVs); + + inline TVSet getTypeVars() { + TVSet Out; + addTypeVars(Out); + return Out; + } + + Type* substitute(const TVSub& Sub); + 
+ TypeIterator begin(); + TypeIterator end(); + + TypeIndex getStartIndex(); + TypeIndex getEndIndex(); + + Type* resolve(const TypeIndex& Index) const noexcept; + + Type* resolve(const TypePath& Path) noexcept { + Type* Ty = this; + for (auto El: Path) { + Ty = Ty->resolve(El); + } + return Ty; + } + + bool operator==(const Type& Other) const noexcept; + + bool operator!=(const Type& Other) const noexcept { + return !(*this == Other); + } + + }; + + class TCon : public Type { + public: + + const size_t Id; + std::vector Args; + ByteString DisplayName; + + inline TCon(const size_t Id, std::vector Args, ByteString DisplayName): + Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Con; + } + + }; + + enum class VarKind { + Rigid, + Unification, + }; + + class TVar : public Type { + public: + + const size_t Id; + VarKind VK; + + TypeclassContext Contexts; + + inline TVar(size_t Id, VarKind VK): + Type(TypeKind::Var), Id(Id), VK(VK) {} + + inline VarKind getVarKind() const noexcept { + return VK; + } + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Var; + } + + }; + + class TVarRigid : public TVar { + public: + + ByteString Name; + + inline TVarRigid(size_t Id, ByteString Name): + TVar(Id, VarKind::Rigid), Name(Name) {} + + }; + + class TArrow : public Type { + public: + + std::vector ParamTypes; + Type* ReturnType; + + inline TArrow( + std::vector ParamTypes, + Type* ReturnType + ): Type(TypeKind::Arrow), + ParamTypes(ParamTypes), + ReturnType(ReturnType) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Arrow; + } + + }; + + class TTuple : public Type { + public: + + std::vector ElementTypes; + + inline TTuple(std::vector ElementTypes): + Type(TypeKind::Tuple), ElementTypes(ElementTypes) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Tuple; + } + + }; + + class TTupleIndex : 
public Type { + public: + + Type* Ty; + std::size_t I; + + inline TTupleIndex(Type* Ty, std::size_t I): + Type(TypeKind::TupleIndex), Ty(Ty), I(I) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::TupleIndex; + } + + }; + + // template + // struct DerefHash { + // std::size_t operator()(const T& Value) const noexcept { + // return std::hash{}(*Value); + // } + // }; + +} diff --git a/src/Checker.cc b/src/Checker.cc index 86e59f2c9..7b24fb051 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -25,165 +25,6 @@ namespace bolt { std::string describe(const Type* Ty); - bool TypeclassSignature::operator<(const TypeclassSignature& Other) const { - if (Id < Other.Id) { - return true; - } - ZEN_ASSERT(Params.size() == 1); - ZEN_ASSERT(Other.Params.size() == 1); - return Params[0]->Id < Other.Params[0]->Id; - } - - bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { - ZEN_ASSERT(Params.size() == 1); - ZEN_ASSERT(Other.Params.size() == 1); - return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id; - } - - void Type::addTypeVars(TVSet& TVs) { - switch (Kind) { - case TypeKind::Var: - TVs.emplace(static_cast(this)); - break; - case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - for (auto Ty: Arrow->ParamTypes) { - Ty->addTypeVars(TVs); - } - Arrow->ReturnType->addTypeVars(TVs); - break; - } - case TypeKind::Con: - { - auto Con = static_cast(this); - for (auto Ty: Con->Args) { - Ty->addTypeVars(TVs); - } - break; - } - case TypeKind::TupleIndex: - { - auto Index = static_cast(this); - Index->Ty->addTypeVars(TVs); - break; - } - case TypeKind::Tuple: - { - auto Tuple = static_cast(this); - for (auto Ty: Tuple->ElementTypes) { - Ty->addTypeVars(TVs); - } - break; - } - } - } - - bool Type::hasTypeVar(const TVar* TV) { - switch (Kind) { - case TypeKind::Var: - return static_cast(this)->Id == TV->Id; - case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - for (auto Ty: Arrow->ParamTypes) { - if 
(Ty->hasTypeVar(TV)) { - return true; - } - } - return Arrow->ReturnType->hasTypeVar(TV); - } - case TypeKind::Con: - { - auto Con = static_cast(this); - for (auto Ty: Con->Args) { - if (Ty->hasTypeVar(TV)) { - return true; - } - } - return false; - } - case TypeKind::TupleIndex: - { - auto Index = static_cast(this); - return Index->Ty->hasTypeVar(TV); - } - case TypeKind::Tuple: - { - auto Tuple = static_cast(this); - for (auto Ty: Tuple->ElementTypes) { - if (Ty->hasTypeVar(TV)) { - return true; - } - } - return false; - } - } - } - - Type* Type::substitute(const TVSub &Sub) { - switch (Kind) { - case TypeKind::Var: - { - auto TV = static_cast(this); - auto Match = Sub.find(TV); - return Match != Sub.end() ? Match->second->substitute(Sub) : this; - } - case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - bool Changed = false; - std::vector NewParamTypes; - for (auto Ty: Arrow->ParamTypes) { - auto NewParamType = Ty->substitute(Sub); - if (NewParamType != Ty) { - Changed = true; - } - NewParamTypes.push_back(NewParamType); - } - auto NewRetTy = Arrow->ReturnType->substitute(Sub) ; - if (NewRetTy != Arrow->ReturnType) { - Changed = true; - } - return Changed ? new TArrow(NewParamTypes, NewRetTy) : this; - } - case TypeKind::Con: - { - auto Con = static_cast(this); - bool Changed = false; - std::vector NewArgs; - for (auto Arg: Con->Args) { - auto NewArg = Arg->substitute(Sub); - if (NewArg != Arg) { - Changed = true; - } - NewArgs.push_back(NewArg); - } - return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this; - } - case TypeKind::TupleIndex: - { - auto Tuple = static_cast(this); - auto NewTy = Tuple->Ty->substitute(Sub); - return NewTy != Tuple->Ty ? 
new TTupleIndex(NewTy, Tuple->I) : Tuple; - } - case TypeKind::Tuple: - { - auto Tuple = static_cast(this); - bool Changed = false; - std::vector NewElementTypes; - for (auto Ty: Tuple->ElementTypes) { - auto NewElementType = Ty->substitute(Sub); - if (NewElementType != Ty) { - Changed = true; - } - NewElementTypes.push_back(NewElementType); - } - return Changed ? new TTuple(NewElementTypes) : this; - } - } - } - Constraint* Constraint::substitute(const TVSub &Sub) { switch (Kind) { case ConstraintKind::Class: @@ -1141,7 +982,7 @@ namespace bolt { ZEN_UNREACHABLE } - std::vector Checker::findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source) { + std::vector Checker::findInstanceContext(TCon* Ty, TypeclassId& Class) { auto Match = InstanceMap.find(Class); std::vector S; if (Match != InstanceMap.end()) { @@ -1164,7 +1005,7 @@ namespace bolt { return S; } - void Checker::propagateClasses(std::unordered_set& Classes, Type* Ty, Node* Source) { + void Checker::propagateClasses(std::unordered_set& Classes, Type* Ty) { if (llvm::isa(Ty)) { auto TV = llvm::cast(Ty); for (auto Class: Classes) { @@ -1172,61 +1013,29 @@ namespace bolt { } } else if (llvm::isa(Ty)) { for (auto Class: Classes) { - propagateClassTycon(Class, llvm::cast(Ty), Source); + propagateClassTycon(Class, llvm::cast(Ty)); } } else if (!Classes.empty()) { DE.add(Ty); } }; - void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source) { - auto S = findInstanceContext(Ty, Class, Source); + void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty) { + auto S = findInstanceContext(Ty, Class); for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { - propagateClasses(Classes, Arg, Source); + propagateClasses(Classes, Arg); } }; - class ArrowCursor { - - std::stack> Path; - - public: - - ArrowCursor(TArrow* Arr) { - Path.push({ Arr, 0 }); - } - - Type* next() { - while (!Path.empty()) { - auto& [Arr, I] = Path.top(); - Type* Ty; - if (I == -1) { - Path.pop(); - continue; - } - if (I == 
Arr->ParamTypes.size()) { - I = -1; - Ty = Arr->ReturnType; - } else { - Ty = Arr->ParamTypes[I]; - I++; - } - if (llvm::isa(Ty)) { - Path.push({ static_cast(Ty), 0 }); - } else { - return Ty; - } - } - return nullptr; - } - - }; - void Checker::solveCEqual(CEqual* C) { - std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; - if (!unify(C->Left, C->Right, C->Source)) { - DE.add(simplify(C->Left), simplify(C->Right), C->Source); - } + /* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */ + OrigLeft = C->Left; + OrigRight = C->Right; + Source = C->Source; + unify(C->Left, C->Right); + LeftPath = {}; + RightPath = {}; + /* DE.add(simplify(C->Left), simplify(C->Right), C->Source); */ } Type* Checker::simplify(Type* Ty) { @@ -1314,11 +1123,11 @@ namespace bolt { return Ty; } - void Checker::join(TVar* TV, Type* Ty, Node* Source) { + void Checker::join(TVar* TV, Type* Ty) { Solution[TV] = Ty; - propagateClasses(TV->Contexts, Ty, Source); + propagateClasses(TV->Contexts, Ty); // This is a very specific adjustment that is critical to the // well-functioning of the infer/unify algorithm. When addConstraint() is @@ -1335,24 +1144,62 @@ namespace bolt { } - bool Checker::unify(Type* A, Type* B, Node* Source) { + void Checker::unifyError() { + DE.add( + simplify(OrigLeft), + simplify(OrigRight), + LeftPath, + RightPath, + Source + ); + } - auto find = [&](auto OrigTy) { - auto Ty = OrigTy; - if (llvm::isa(Ty)) { - auto TV = static_cast(Ty); - do { - auto Match = Solution.find(static_cast(Ty)); - if (Match == Solution.end()) { - break; - } - Ty = Match->second; - } while (Ty->getKind() == TypeKind::Var); - // FIXME does this actually improove performance? 
- Solution[TV] = Ty; + class ArrowCursor { + + std::stack> Stack; + TypePath& Path; + std::size_t I; + + public: + + ArrowCursor(TArrow* Arr, TypePath& Path): + Path(Path) { + Stack.push({ Arr, true }); + Path.push_back(Arr->getStartIndex()); } - return Ty; - }; + + Type* next() { + while (!Stack.empty()) { + auto& [Arrow, First] = Stack.top(); + auto& Index = Path.back(); + if (!First) { + Index.advance(Arrow); + } else { + First = false; + } + Type* Ty; + if (Index == Arrow->getEndIndex()) { + Path.pop_back(); + Stack.pop(); + continue; + } + Ty = Arrow->resolve(Index); + if (llvm::isa(Ty)) { + auto NewIndex = Arrow->getStartIndex(); + Stack.push({ static_cast(Ty), true }); + Path.push_back(NewIndex); + } else { + return Ty; + } + } + return nullptr; + } + + }; + + + + bool Checker::unify(Type* A, Type* B) { A = simplify(A); B = simplify(B); @@ -1362,6 +1209,7 @@ namespace bolt { auto Var2 = static_cast(B); if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { if (Var1->Id != Var2->Id) { + unifyError(); return false; } return true; @@ -1373,38 +1221,47 @@ namespace bolt { From = Var2; } else { // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification - // Either way, Var1 is a good candidate for being unified away + // Either way, Var1, being Unification, is a good candidate for being unified away To = Var2; From = Var1; } - join(From, To, Source); - propagateClasses(From->Contexts, To, Source); + join(From, To); return true; } if (llvm::isa(A)) { + auto TV = static_cast(A); + + // Rigid type variables can never unify with antything else than what we + // have already handled in the previous if-statement, so issue an error. if (TV->getVarKind() == VarKind::Rigid) { + unifyError(); return false; } + // Occurs check if (B->hasTypeVar(TV)) { // NOTE Just like GHC, we just display an error message indicating that // A cannot match B, e.g. a cannot match [a]. 
It looks much better // than obsure references to an occurs check + unifyError(); return false; } - join(TV, B, Source); + + join(TV, B); + return true; } if (llvm::isa(B)) { - return unify(B, A, Source); + return unify(B, A); } if (llvm::isa(A) && llvm::isa(B)) { - auto C1 = ArrowCursor(static_cast(A)); - auto C2 = ArrowCursor(static_cast(B)); + auto C1 = ArrowCursor(static_cast(A), LeftPath); + auto C2 = ArrowCursor(static_cast(B), RightPath); + bool Success = true; for (;;) { auto T1 = C1.next(); auto T2 = C2.next(); @@ -1412,13 +1269,15 @@ namespace bolt { break; } if (T1 == nullptr || T2 == nullptr) { - return false; + unifyError(); + Success = false; + break; } - if (!unify(T1, T2, Source)) { - return false; + if (!unify(T1, T2)) { + Success = false; } } - return true; + return Success; /* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */ /* return false; */ /* } */ @@ -1434,26 +1293,31 @@ namespace bolt { if (llvm::isa(A)) { auto Arr = static_cast(A); if (Arr->ParamTypes.empty()) { - return unify(Arr->ReturnType, B, Source); + return unify(Arr->ReturnType, B); } } if (llvm::isa(B)) { - return unify(B, A, Source); + return unify(B, A); } if (llvm::isa(A) && llvm::isa(B)) { auto Tuple1 = static_cast(A); auto Tuple2 = static_cast(B); if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { + unifyError(); return false; } auto Count = Tuple1->ElementTypes.size(); bool Success = true; for (size_t I = 0; I < Count; I++) { - if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], Source)) { + LeftPath.push_back(TypeIndex::forTupleElement(I)); + RightPath.push_back(TypeIndex::forTupleElement(I)); + if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) { Success = false; } + LeftPath.pop_back(); + RightPath.pop_back(); } return Success; } @@ -1468,18 +1332,25 @@ namespace bolt { auto Con1 = static_cast(A); auto Con2 = static_cast(B); if (Con1->Id != Con2->Id) { + unifyError(); return false; } ZEN_ASSERT(Con1->Args.size() == 
Con2->Args.size()); auto Count = Con1->Args.size(); + bool Success = true; for (std::size_t I = 0; I < Count; I++) { - if (!unify(Con1->Args[I], Con2->Args[I], Source)) { - return false; + LeftPath.push_back(TypeIndex::forConArg(I)); + RightPath.push_back(TypeIndex::forConArg(I)); + if (!unify(Con1->Args[I], Con2->Args[I])) { + Success = false; } + LeftPath.pop_back(); + RightPath.pop_back(); } - return true; + return Success; } + unifyError(); return false; } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 8844df560..d26a4c485 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -438,14 +438,29 @@ namespace bolt { setBold(true); Out << "error: "; resetStyles(); - Out << "the types " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET - << " and " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << " failed to match\n\n"; + auto Left = E.Left->resolve(E.LeftPath); + auto Right = E.Right->resolve(E.RightPath); + Out << "the types " << ANSI_FG_GREEN << describe(Left) << ANSI_RESET + << " and " << ANSI_FG_GREEN << describe(Right) << ANSI_RESET << " failed to match\n\n"; if (E.Source) { auto Range = E.Source->getRange(); - //std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n"; writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red); Out << "\n"; } + if (!E.LeftPath.empty()) { + setForegroundColor(Color::Yellow); + setBold(true); + Out << " info: "; + resetStyles(); + Out << "type " << ANSI_FG_GREEN << describe(Left) << ANSI_RESET << " occurs in the full type " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET << "\n\n"; + } + if (!E.RightPath.empty()) { + setForegroundColor(Color::Yellow); + setBold(true); + Out << " info: "; + resetStyles(); + Out << "type " << ANSI_FG_GREEN << describe(Right) << ANSI_RESET << " occurs in the full type " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << "\n\n"; + } return; } diff --git a/src/Types.cc b/src/Types.cc new 
file mode 100644
index 10ab6e089
--- /dev/null
+++ b/src/Types.cc
@@ -0,0 +1,408 @@
+
+#include "llvm/Support/Casting.h"
+
+#include "zen/config.hpp"
+#include "zen/range.hpp"
+
+#include "bolt/Type.hpp"
+
+namespace bolt {
+
+  // Orders signatures by class Id first and by the Id of the single type
+  // parameter second. Only one-parameter signatures are supported when the
+  // class Ids tie, as enforced by the assertions.
+  bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
+    // Decide on the class Id alone whenever the Ids differ. Falling through to
+    // the parameter comparison when Id > Other.Id would let both A < B and
+    // B < A hold, which breaks the strict weak ordering that ordered
+    // containers and sorting require.
+    if (!(Id == Other.Id)) {
+      return Id < Other.Id;
+    }
+    ZEN_ASSERT(Params.size() == 1);
+    ZEN_ASSERT(Other.Params.size() == 1);
+    return Params[0]->Id < Other.Params[0]->Id;
+  }
+
+  // Two signatures are equal when the class Ids match and their single type
+  // parameters carry the same Id.
+  bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
+    ZEN_ASSERT(Params.size() == 1);
+    ZEN_ASSERT(Other.Params.size() == 1);
+    return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
+  }
+
+  // Indices are equal when their kinds match and, for the kinds that carry a
+  // position (ConArg, ArrowParamType, TupleElement), the positions match too.
+  bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
+    if (Kind != Other.Kind) {
+      return false;
+    }
+    switch (Kind) {
+      case TypeIndexKind::ConArg:
+      case TypeIndexKind::ArrowParamType:
+      case TypeIndexKind::TupleElement:
+        return I == Other.I;
+      default:
+        return true;
+    }
+  }
+
+  // Moves this index to the next child of Ty, or to End once the last child
+  // has been passed. Ty must be of the kind this index refers into, because
+  // it is cast to that kind in order to read the number of children.
+  void TypeIndex::advance(const Type* Ty) {
+    switch (Kind) {
+      case TypeIndexKind::End:
+        break;
+      case TypeIndexKind::ArrowParamType:
+      {
+        auto Arrow = llvm::cast<TArrow>(Ty);
+        if (I+1 < Arrow->ParamTypes.size()) {
+          ++I;
+        } else {
+          // The return type is visited after the last parameter.
+          Kind = TypeIndexKind::ArrowReturnType;
+        }
+        break;
+      }
+      case TypeIndexKind::ArrowReturnType:
+        Kind = TypeIndexKind::End;
+        break;
+      case TypeIndexKind::ConArg:
+      {
+        auto Con = llvm::cast<TCon>(Ty);
+        if (I+1 < Con->Args.size()) {
+          ++I;
+        } else {
+          Kind = TypeIndexKind::End;
+        }
+        break;
+      }
+      case TypeIndexKind::TupleElement:
+      {
+        auto Tuple = llvm::cast<TTuple>(Ty);
+        if (I+1 < Tuple->ElementTypes.size()) {
+          ++I;
+        } else {
+          Kind = TypeIndexKind::End;
+        }
+        break;
+      }
+    }
+  }
+
+  // Inserts every type variable that occurs in this type into TVs.
+  void Type::addTypeVars(TVSet& TVs) {
+    switch (Kind) {
+      case TypeKind::Var:
+        TVs.emplace(static_cast<TVar*>(this));
+        break;
+      case TypeKind::Arrow:
+      {
+        auto Arrow = static_cast<TArrow*>(this);
+        for (auto Ty: Arrow->ParamTypes) {
+          Ty->addTypeVars(TVs);
+        }
+        Arrow->ReturnType->addTypeVars(TVs);
+        break;
+      }
+      case TypeKind::Con:
+      {
+        auto Con = static_cast<TCon*>(this);
+        for (auto Ty: Con->Args) {
+          Ty->addTypeVars(TVs);
+        }
+        break;
+      }
+      case TypeKind::TupleIndex:
+      {
+        auto Index = static_cast<TTupleIndex*>(this);
+        Index->Ty->addTypeVars(TVs);
+        break;
+      }
+      case TypeKind::Tuple:
+      {
+        auto Tuple = static_cast<TTuple*>(this);
+        for (auto Ty: Tuple->ElementTypes) {
+          Ty->addTypeVars(TVs);
+        }
+        break;
+      }
+    }
+  }
+
+  // Returns whether a type variable with the same Id as TV occurs anywhere in
+  // this type.
+  bool Type::hasTypeVar(const TVar* TV) {
+    switch (Kind) {
+      case TypeKind::Var:
+        return static_cast<TVar*>(this)->Id == TV->Id;
+      case TypeKind::Arrow:
+      {
+        auto Arrow = static_cast<TArrow*>(this);
+        for (auto Ty: Arrow->ParamTypes) {
+          if (Ty->hasTypeVar(TV)) {
+            return true;
+          }
+        }
+        return Arrow->ReturnType->hasTypeVar(TV);
+      }
+      case TypeKind::Con:
+      {
+        auto Con = static_cast<TCon*>(this);
+        for (auto Ty: Con->Args) {
+          if (Ty->hasTypeVar(TV)) {
+            return true;
+          }
+        }
+        return false;
+      }
+      case TypeKind::TupleIndex:
+      {
+        auto Index = static_cast<TTupleIndex*>(this);
+        return Index->Ty->hasTypeVar(TV);
+      }
+      case TypeKind::Tuple:
+      {
+        auto Tuple = static_cast<TTuple*>(this);
+        for (auto Ty: Tuple->ElementTypes) {
+          if (Ty->hasTypeVar(TV)) {
+            return true;
+          }
+        }
+        return false;
+      }
+    }
+    // Every TypeKind returns above; do not fall off the end of a non-void function.
+    ZEN_UNREACHABLE
+  }
+
+  // Applies Sub to this type. A variable found in Sub is replaced by its
+  // (recursively substituted) target. A composite type is rebuilt only when
+  // one of its children changed; otherwise this same object is returned.
+  Type* Type::substitute(const TVSub &Sub) {
+    switch (Kind) {
+      case TypeKind::Var:
+      {
+        auto TV = static_cast<TVar*>(this);
+        auto Match = Sub.find(TV);
+        return Match != Sub.end() ? Match->second->substitute(Sub) : this;
+      }
+      case TypeKind::Arrow:
+      {
+        auto Arrow = static_cast<TArrow*>(this);
+        bool Changed = false;
+        std::vector<Type*> NewParamTypes;
+        for (auto Ty: Arrow->ParamTypes) {
+          auto NewParamType = Ty->substitute(Sub);
+          if (NewParamType != Ty) {
+            Changed = true;
+          }
+          NewParamTypes.push_back(NewParamType);
+        }
+        auto NewRetTy = Arrow->ReturnType->substitute(Sub);
+        if (NewRetTy != Arrow->ReturnType) {
+          Changed = true;
+        }
+        return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
+      }
+      case TypeKind::Con:
+      {
+        auto Con = static_cast<TCon*>(this);
+        bool Changed = false;
+        std::vector<Type*> NewArgs;
+        for (auto Arg: Con->Args) {
+          auto NewArg = Arg->substitute(Sub);
+          if (NewArg != Arg) {
+            Changed = true;
+          }
+          NewArgs.push_back(NewArg);
+        }
+        return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
+      }
+      case TypeKind::TupleIndex:
+      {
+        auto Tuple = static_cast<TTupleIndex*>(this);
+        auto NewTy = Tuple->Ty->substitute(Sub);
+        return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
+      }
+      case TypeKind::Tuple:
+      {
+        auto Tuple = static_cast<TTuple*>(this);
+        bool Changed = false;
+        std::vector<Type*> NewElementTypes;
+        for (auto Ty: Tuple->ElementTypes) {
+          auto NewElementType = Ty->substitute(Sub);
+          if (NewElementType != Ty) {
+            Changed = true;
+          }
+          NewElementTypes.push_back(NewElementType);
+        }
+        return Changed ? new TTuple(NewElementTypes) : this;
+      }
+    }
+    ZEN_UNREACHABLE
+  }
+
+  // Returns the child of this type that Index points at. The kind of Index
+  // must fit the kind of this type, since this is cast accordingly, and Index
+  // must not be the End index.
+  Type* Type::resolve(const TypeIndex& Index) const noexcept {
+    switch (Index.Kind) {
+      case TypeIndexKind::ConArg:
+        return llvm::cast<TCon>(this)->Args[Index.I];
+      case TypeIndexKind::TupleElement:
+        return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
+      case TypeIndexKind::ArrowParamType:
+        return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
+      case TypeIndexKind::ArrowReturnType:
+        return llvm::cast<TArrow>(this)->ReturnType;
+      case TypeIndexKind::End:
+        ZEN_UNREACHABLE
+    }
+    ZEN_UNREACHABLE
+  }
+
+  // Structural equality: two types are equal when they have the same kind and
+  // all corresponding children are equal. Type variables compare by Id and
+  // type constructors additionally compare their Id.
+  bool Type::operator==(const Type& Other) const noexcept {
+    switch (Kind) {
+      case TypeKind::Var:
+        if (Other.Kind != TypeKind::Var) {
+          return false;
+        }
+        return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
+      case TypeKind::Tuple:
+      {
+        if (Other.Kind != TypeKind::Tuple) {
+          return false;
+        }
+        // Bind by reference; copying would duplicate the child vectors.
+        const auto& A = static_cast<const TTuple&>(*this);
+        const auto& B = static_cast<const TTuple&>(Other);
+        if (A.ElementTypes.size() != B.ElementTypes.size()) {
+          return false;
+        }
+        for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
+          if (*T1 != *T2) {
+            return false;
+          }
+        }
+        return true;
+      }
+      case TypeKind::TupleIndex:
+      {
+        if (Other.Kind != TypeKind::TupleIndex) {
+          return false;
+        }
+        const auto& A = static_cast<const TTupleIndex&>(*this);
+        const auto& B = static_cast<const TTupleIndex&>(Other);
+        return A.I == B.I && *A.Ty == *B.Ty;
+      }
+      case TypeKind::Con:
+      {
+        if (Other.Kind != TypeKind::Con) {
+          return false;
+        }
+        const auto& A = static_cast<const TCon&>(*this);
+        const auto& B = static_cast<const TCon&>(Other);
+        if (A.Id != B.Id) {
+          return false;
+        }
+        if (A.Args.size() != B.Args.size()) {
+          return false;
+        }
+        for (auto [T1, T2]: zen::zip(A.Args, B.Args)) {
+          if (*T1 != *T2) {
+            return false;
+          }
+        }
+        return true;
+      }
+      case TypeKind::Arrow:
+      {
+        // FIXME Do we really need to 'curry' this type?
+        if (Other.Kind != TypeKind::Arrow) {
+          return false;
+        }
+        const auto& A = static_cast<const TArrow&>(*this);
+        const auto& B = static_cast<const TArrow&>(Other);
+        /* ArrowCursor C1 { &A }; */
+        /* ArrowCursor C2 { &B }; */
+        /* for (;;) { */
+        /*   auto T1 = C1.next(); */
+        /*   auto T2 = C2.next(); */
+        /*   if (T1 == nullptr && T2 == nullptr) { */
+        /*     break; */
+        /*   } */
+        /*   if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */
+        /*     return false; */
+        /*   } */
+        /* } */
+        if (A.ParamTypes.size() != B.ParamTypes.size()) {
+          return false;
+        }
+        for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) {
+          if (*T1 != *T2) {
+            return false;
+          }
+        }
+        // Compare the return types structurally, like every other child. A
+        // pointer inequality test here made identical arrows compare unequal.
+        return *A.ReturnType == *B.ReturnType;
+      }
+    }
+    ZEN_UNREACHABLE
+  }
+
+  // Iterator positioned at the first child index of this type.
+  TypeIterator Type::begin() {
+    return TypeIterator { this, getStartIndex() };
+  }
+
+  // Iterator positioned at the End index of this type.
+  TypeIterator Type::end() {
+    return TypeIterator { this, getEndIndex() };
+  }
+
+  // Index of the first child of this type, or End when there is none. An arrow
+  // without parameters starts at its return type.
+  TypeIndex Type::getStartIndex() {
+    switch (Kind) {
+      case TypeKind::Con:
+      {
+        auto Con = static_cast<TCon*>(this);
+        if (Con->Args.empty()) {
+          return TypeIndex(TypeIndexKind::End);
+        }
+        return TypeIndex::forConArg(0);
+      }
+      case TypeKind::Arrow:
+      {
+        auto Arrow = static_cast<TArrow*>(this);
+        if (Arrow->ParamTypes.empty()) {
+          return TypeIndex::forArrowReturnType();
+        }
+        return TypeIndex::forArrowParamType(0);
+      }
+      case TypeKind::Tuple:
+      {
+        auto Tuple = static_cast<TTuple*>(this);
+        if (Tuple->ElementTypes.empty()) {
+          return TypeIndex(TypeIndexKind::End);
+        }
+        return TypeIndex::forTupleElement(0);
+      }
+      default:
+        return TypeIndex(TypeIndexKind::End);
+    }
+  }
+
+  // The index that marks the position past the last child of any type.
+  TypeIndex Type::getEndIndex() {
+    return TypeIndex(TypeIndexKind::End);
+  }
+
+}