diff --git a/bootstrap/cxx/include/bolt/CST.hpp b/bootstrap/cxx/include/bolt/CST.hpp index 4a766457d..fda3ef6e8 100644 --- a/bootstrap/cxx/include/bolt/CST.hpp +++ b/bootstrap/cxx/include/bolt/CST.hpp @@ -184,6 +184,12 @@ namespace bolt { template NodeKind getNodeType(); + enum NodeFlags { + NodeFlags_TypeIsSolved = 1, + }; + + using NodeFlagsMask = unsigned; + class Node { unsigned RefCount = 1; @@ -192,6 +198,7 @@ namespace bolt { public: + NodeFlagsMask Flags = 0; Node* Parent = nullptr; inline void ref() { diff --git a/bootstrap/cxx/include/bolt/Checker.hpp b/bootstrap/cxx/include/bolt/Checker.hpp index e41a528d0..2c750d670 100644 --- a/bootstrap/cxx/include/bolt/Checker.hpp +++ b/bootstrap/cxx/include/bolt/Checker.hpp @@ -178,6 +178,7 @@ namespace bolt { Type* ListType; Type* IntType; Type* StringType; + Type* UnitType; Graph RefGraph; @@ -217,9 +218,9 @@ namespace bolt { /// Factory methods - TCon* createConType(ByteString Name); - TVar* createTypeVar(); - TVarRigid* createRigidVar(ByteString Name); + Type* createConType(ByteString Name); + Type* createTypeVar(); + Type* createRigidVar(ByteString Name); InferContext* createInferContext( InferContext* Parent = nullptr, TVSet* TVs = new TVSet, @@ -280,6 +281,11 @@ namespace bolt { */ Type* simplifyType(Type* Ty); + /** + * \internal + */ + Type* solveType(Type* Ty); + void check(SourceFile* SF); inline Type* getBoolType() const { diff --git a/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp b/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp index d7018c319..867b35844 100644 --- a/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp +++ b/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp @@ -22,12 +22,17 @@ namespace bolt { public: + bool FailOnError = false; + inline bool hasError() const noexcept { return HasError; } template void add(Ts&&... Args) { + // if (FailOnError) { + // ZEN_PANIC("An error diagnostic caused the program to abort."); + // } HasError = true; addDiagnostic(new D { std::forward(Args)... }); } diff --git a/bootstrap/cxx/include/bolt/Diagnostics.hpp b/bootstrap/cxx/include/bolt/Diagnostics.hpp index 8e2067da0..3d44ace13 100644 --- a/bootstrap/cxx/include/bolt/Diagnostics.hpp +++ b/bootstrap/cxx/include/bolt/Diagnostics.hpp @@ -2,7 +2,6 @@ #pragma once #include -#include #include "bolt/ByteString.hpp" #include "bolt/String.hpp" @@ -12,15 +11,16 @@ namespace bolt { enum class DiagnosticKind : unsigned char { - UnexpectedToken, - UnexpectedString, BindingNotFound, - UnificationError, - TypeclassMissing, - InstanceNotFound, - TupleIndexOutOfRange, - InvalidTypeToTypeclass, FieldNotFound, + InstanceNotFound, + InvalidTypeToTypeclass, + NotATuple, + TupleIndexOutOfRange, + TypeclassMissing, + UnexpectedString, + UnexpectedToken, + UnificationError, }; class Diagnostic : std::runtime_error { @@ -168,10 +168,10 @@ namespace bolt { class TupleIndexOutOfRangeDiagnostic : public Diagnostic { public: - TTuple* Tuple; + Type* Tuple; std::size_t I; - inline TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I): + inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I): Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {} unsigned getCode() const noexcept override { @@ -209,7 +209,7 @@ namespace bolt { Node* Source; inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source): - Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} + Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} unsigned getCode() const noexcept override { return 2017; @@ -217,4 +217,18 @@ namespace bolt { }; + class NotATupleDiagnostic : public Diagnostic { + public: + + Type* Ty; + + inline NotATupleDiagnostic(Type* Ty): + Diagnostic(DiagnosticKind::NotATuple), Ty(Ty) {} + + unsigned getCode() const noexcept override { + return 2016; + } + + }; + } diff --git a/bootstrap/cxx/include/bolt/Type.hpp b/bootstrap/cxx/include/bolt/Type.hpp index ab5ae8b71..a123ce30f 100644 --- a/bootstrap/cxx/include/bolt/Type.hpp +++ b/bootstrap/cxx/include/bolt/Type.hpp @@ -2,17 +2,21 @@ #pragma once #include -#include -#include -#include +#include +#include #include +#include +#include +#include "zen/config.hpp" +#include "zen/range.hpp" + +#include "bolt/CST.hpp" #include "bolt/ByteString.hpp" namespace bolt { class Type; - class TVar; class TCon; using TypeclassId = ByteString; @@ -23,7 +27,7 @@ namespace bolt { using TypeclassId = ByteString; TypeclassId Id; - std::vector Params; + std::vector Params; bool operator<(const TypeclassSignature& Other) const; bool operator==(const TypeclassSignature& Other) const; @@ -144,8 +148,8 @@ namespace bolt { using TypePath = std::vector; - using TVSub = std::unordered_map; - using TVSet = std::unordered_set; + using TVSub = std::unordered_map; + using TVSet = std::unordered_set; enum class TypeKind : unsigned char { Var, @@ -160,48 +164,402 @@ namespace bolt { Present, }; - class Type { + class Type; - const TypeKind Kind; + struct TCon { + size_t Id; + ByteString DisplayName; - protected: + bool operator==(const TCon& Other) const; - inline Type(TypeKind Kind): - Kind(Kind) {} + }; - public: + struct TApp { + Type* Op; + Type* Arg; - inline TypeKind getKind() const noexcept { + bool operator==(const TApp& Other) const; + + }; + + enum class VarKind { + Rigid, + Unification, + }; + + struct TVar { + VarKind VK; + size_t Id; + TypeclassContext Context; + std::optional Name; + std::optional Provided; + + VarKind getKind() const { + return VK; + } + + bool isUni() const { + return VK == VarKind::Unification; + } + + bool isRigid() const { + return VK == VarKind::Rigid; + } + + bool operator==(const TVar& Other) const; + + }; + + struct TArrow { + Type* ParamType; + Type* ReturnType; + + bool operator==(const TArrow& Other) const; + + }; + + struct TTuple { + std::vector ElementTypes; + + bool operator==(const TTuple& Other) const; + + }; + + struct TTupleIndex { + + Type* Ty; + std::size_t I; + + bool operator==(const TTupleIndex& Other) const; + + }; + + struct TNil { + bool operator==(const TNil& Other) const; + }; + + struct TField { + ByteString Name; + Type* Ty; + Type* RestTy; + bool operator==(const TField& Other) const; + }; + + struct TAbsent { + bool operator==(const TAbsent& Other) const; + }; + + struct TPresent { + Type* Ty; + bool operator==(const TPresent& Other) const; + }; + + struct Type { + + TypeKind Kind; + + Type* Parent = this; + + union { + TCon Con; + TApp App; + TVar Var; + TArrow Arrow; + TTuple Tuple; + TTupleIndex TupleIndex; + TNil Nil; + TField Field; + TAbsent Absent; + TPresent Present; + }; + + Type(TCon&& Con): + Kind(TypeKind::Con), Con(std::move(Con)) {}; + + Type(TApp&& App): + Kind(TypeKind::App), App(std::move(App)) {}; + + Type(TVar&& Var): + Kind(TypeKind::Var), Var(std::move(Var)) {}; + + Type(TArrow&& Arrow): + Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {}; + + Type(TTuple&& Tuple): + Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {}; + + Type(TTupleIndex&& TupleIndex): + Kind(TypeKind::TupleIndex), TupleIndex(std::move(TupleIndex)) {}; + + Type(TNil&& Nil): + Kind(TypeKind::Nil), Nil(std::move(Nil)) {}; + + Type(TField&& Field): + Kind(TypeKind::Field), Field(std::move(Field)) {}; + + Type(TAbsent&& Absent): + Kind(TypeKind::Absent), Absent(std::move(Absent)) {}; + + Type(TPresent&& Present): + Kind(TypeKind::Present), Present(std::move(Present)) {}; + +// Type(TCon Con): Kind(TypeKind::Con) { +// new (&Con)TCon(Con); +// } + +// Type(TApp App): Kind(TypeKind::App) { +// new (&App)TApp(App); +// } + +// Type(TVar Var): Kind(TypeKind::Var) { +// new (&Var)TVar(Var); +// } + +// Type(TArrow Arrow): Kind(TypeKind::Arrow) { +// new (&Arrow)TArrow(Arrow); +// } + +// Type(TTuple Tuple): Kind(TypeKind::Tuple) { +// new (&Tuple)TTuple(Tuple); +// } + +// Type(TTupleIndex TupleIndex): Kind(TypeKind::TupleIndex) { +// new (&TupleIndex)TTupleIndex(TupleIndex); +// } + +// Type(TNil Nil): Kind(TypeKind::Nil) { +// new (&Nil)TNil(Nil); +// } + +// Type(TField Field): Kind(TypeKind::Field) { +// new (&Field)TField(Field); +// } + +// Type(TAbsent Absent): Kind(TypeKind::Absent) { +// new (&Absent)TAbsent(Absent); +// } + +// Type(TPresent Present): Kind(TypeKind::Present) { +// new (&Present)TPresent(Present); +// } + + Type(const Type& Other): Kind(Other.Kind) { + switch (Kind) { + case TypeKind::Con: + new (&Con)TCon(Other.Con); + break; + case TypeKind::App: + new (&App)TApp(Other.App); + break; + case TypeKind::Var: + new (&Var)TVar(Other.Var); + break; + case TypeKind::Arrow: + new (&Arrow)TArrow(Other.Arrow); + break; + case TypeKind::Tuple: + new (&Tuple)TTuple(Other.Tuple); + break; + case TypeKind::TupleIndex: + new (&TupleIndex)TTupleIndex(Other.TupleIndex); + break; + case TypeKind::Nil: + new (&Nil)TNil(Other.Nil); + break; + case TypeKind::Field: + new (&Field)TField(Other.Field); + break; + case TypeKind::Absent: + new (&Absent)TAbsent(Other.Absent); + break; + case TypeKind::Present: + new (&Present)TPresent(Other.Present); + break; + } + } + + Type(Type&& Other): Kind(std::move(Other.Kind)) { + switch (Kind) { + case TypeKind::Con: + new (&Con)TCon(std::move(Other.Con)); + break; + case TypeKind::App: + new (&App)TApp(std::move(Other.App)); + break; + case TypeKind::Var: + new (&Var)TVar(std::move(Other.Var)); + break; + case TypeKind::Arrow: + new (&Arrow)TArrow(std::move(Other.Arrow)); + break; + case TypeKind::Tuple: + new (&Tuple)TTuple(std::move(Other.Tuple)); + break; + case TypeKind::TupleIndex: + new (&TupleIndex)TTupleIndex(std::move(Other.TupleIndex)); + break; + case TypeKind::Nil: + new (&Nil)TNil(std::move(Other.Nil)); + break; + case TypeKind::Field: + new (&Field)TField(std::move(Other.Field)); + break; + case TypeKind::Absent: + new (&Absent)TAbsent(std::move(Other.Absent)); + break; + case TypeKind::Present: + new (&Present)TPresent(std::move(Other.Present)); + break; + } + } + + TypeKind getKind() const { return Kind; } - bool hasTypeVar(const TVar* TV); - - void addTypeVars(TVSet& TVs); - - inline TVSet getTypeVars() { - TVSet Out; - addTypeVars(Out); - return Out; + bool isVarRigid() const { + return Kind == TypeKind::Var + && asVar().getKind() == VarKind::Rigid; } - /** - * Rewrites the entire substructure of a type to another one. - * - * \param Recursive If true, a succesfull local rewritten type will be again - * rewriten until it encounters some terminals. - */ - Type* rewrite(std::function Fn, bool Recursive = false); + bool isVar() const { + return Kind == TypeKind::Var; + } - Type* substitute(const TVSub& Sub); + TVar& asVar() { + ZEN_ASSERT(Kind == TypeKind::Var); + return Var; + } - Type* solve(); + const TVar& asVar() const { + ZEN_ASSERT(Kind == TypeKind::Var); + return Var; + } - TypeIterator begin(); - TypeIterator end(); + bool isApp() const { + return Kind == TypeKind::App; + } - TypeIndex getStartIndex(); - TypeIndex getEndIndex(); + TApp& asApp() { + ZEN_ASSERT(Kind == TypeKind::App); + return App; + } + + const TApp& asApp() const { + ZEN_ASSERT(Kind == TypeKind::App); + return App; + } + + bool isCon() const { + return Kind == TypeKind::Con; + } + + TCon& asCon() { + ZEN_ASSERT(Kind == TypeKind::Con); + return Con; + } + + const TCon& asCon() const { + ZEN_ASSERT(Kind == TypeKind::Con); + return Con; + } + + bool isArrow() const { + return Kind == TypeKind::Arrow; + } + + TArrow& asArrow() { + ZEN_ASSERT(Kind == TypeKind::Arrow); + return Arrow; + } + + const TArrow& asArrow() const { + ZEN_ASSERT(Kind == TypeKind::Arrow); + return Arrow; + } + + bool isTuple() const { + return Kind == TypeKind::Tuple; + } + + TTuple& asTuple() { + ZEN_ASSERT(Kind == TypeKind::Tuple); + return Tuple; + } + + const TTuple& asTuple() const { + ZEN_ASSERT(Kind == TypeKind::Tuple); + return Tuple; + } + + bool isTupleIndex() const { + return Kind == TypeKind::TupleIndex; + } + + TTupleIndex& asTupleIndex() { + ZEN_ASSERT(Kind == TypeKind::TupleIndex); + return TupleIndex; + } + + const TTupleIndex& asTupleIndex() const { + ZEN_ASSERT(Kind == TypeKind::TupleIndex); + return TupleIndex; + } + + bool isField() const { + return Kind == TypeKind::Field; + } + + TField& asField() { + ZEN_ASSERT(Kind == TypeKind::Field); + return Field; + } + + const TField& asField() const { + ZEN_ASSERT(Kind == TypeKind::Field); + return Field; + } + + bool isAbsent() const { + return Kind == TypeKind::Absent; + } + + TAbsent& asAbsent() { + ZEN_ASSERT(Kind == TypeKind::Absent); + return Absent; + } + const TAbsent& asAbsent() const { + ZEN_ASSERT(Kind == TypeKind::Absent); + return Absent; + } + + bool isPresent() const { + return Kind == TypeKind::Present; + } + + TPresent& asPresent() { + ZEN_ASSERT(Kind == TypeKind::Present); + return Present; + } + const TPresent& asPresent() const { + ZEN_ASSERT(Kind == TypeKind::Present); + return Present; + } + + bool isNil() const { + return Kind == TypeKind::Nil; + } + + TNil& asNil() { + ZEN_ASSERT(Kind == TypeKind::Nil); + return Nil; + } + const TNil& asNil() const { + ZEN_ASSERT(Kind == TypeKind::Nil); + return Nil; + } + + Type* rewrite(std::function Fn, bool Recursive = true); Type* resolve(const TypeIndex& Index) const noexcept; @@ -213,207 +571,128 @@ namespace bolt { return Ty; } - bool operator==(const Type& Other) const noexcept; - - bool operator!=(const Type& Other) const noexcept { - return !(*this == Other); + void set(Type* Ty) { + auto Root = find(); + // It is not possible to set a solution twice. + if (isVar()) { + ZEN_ASSERT(Root->isVar()); + } + Root->Parent = Ty; } - }; - - class TCon : public Type { - public: - - const size_t Id; - ByteString DisplayName; - - inline TCon(const size_t Id, ByteString DisplayName): - Type(TypeKind::Con), Id(Id), DisplayName(DisplayName) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Con; + Type* find() const { + Type* Curr = const_cast(this); + for (;;) { + auto Keep = Curr->Parent; + if (Keep == Curr) { + return Keep; + } + Curr->Parent = Keep->Parent; + Curr = Keep; + } } - }; + bool operator==(const Type& Other) const; - class TApp : public Type { - public: - - Type* Op; - Type* Arg; - - inline TApp(Type* Op, Type* Arg): - Type(TypeKind::App), Op(Op), Arg(Arg) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::App; + void destroy() { + switch (Kind) { + case TypeKind::Con: + App.~TApp(); + break; + case TypeKind::App: + App.~TApp(); + break; + case TypeKind::Var: + Var.~TVar(); + break; + case TypeKind::Arrow: + Arrow.~TArrow(); + break; + case TypeKind::Tuple: + Tuple.~TTuple(); + break; + case TypeKind::TupleIndex: + TupleIndex.~TTupleIndex(); + break; + case TypeKind::Nil: + Nil.~TNil(); + break; + case TypeKind::Field: + Field.~TField(); + break; + case TypeKind::Absent: + Absent.~TAbsent(); + break; + case TypeKind::Present: + Present.~TPresent(); + break; + } } - }; - - enum class VarKind { - Rigid, - Unification, - }; - - class TVar : public Type { - - Type* Parent = this; - - public: - - const size_t Id; - VarKind VK; - - TypeclassContext Contexts; - - inline TVar(size_t Id, VarKind VK): - Type(TypeKind::Var), Id(Id), VK(VK) {} - - inline VarKind getVarKind() const noexcept { - return VK; + Type& operator=(Type& Other) { + destroy(); + Kind = Other.Kind; + switch (Kind) { + case TypeKind::Con: + App = Other.App; + break; + case TypeKind::App: + App = Other.App; + break; + case TypeKind::Var: + Var = Other.Var; + break; + case TypeKind::Arrow: + Arrow = Other.Arrow; + break; + case TypeKind::Tuple: + Tuple = Other.Tuple; + break; + case TypeKind::TupleIndex: + TupleIndex = Other.TupleIndex; + break; + case TypeKind::Nil: + Nil = Other.Nil; + break; + case TypeKind::Field: + Field = Other.Field; + break; + case TypeKind::Absent: + Absent = Other.Absent; + break; + case TypeKind::Present: + Present = Other.Present; + break; + } + return *this; } - inline bool isRigid() const noexcept { - return VK == VarKind::Rigid; + bool hasTypeVar(Type* TV) const; + + TypeIterator begin(); + TypeIterator end(); + + TypeIndex getStartIndex() const; + TypeIndex getEndIndex() const; + + Type* substitute(const TVSub& Sub); + + void visitEachChild(std::function Proc); + + TVSet getTypeVars(); + + ~Type() { + destroy(); } - Type* find(); - - void set(Type* Ty); - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Var; - } - - }; - - class TVarRigid : public TVar { - public: - - ByteString Name; - - TypeclassContext Provided; - - inline TVarRigid(size_t Id, ByteString Name): - TVar(Id, VarKind::Rigid), Name(Name) {} - - }; - - class TArrow : public Type { - public: - - Type* ParamType; - Type* ReturnType; - - inline TArrow( - Type* ParamType, - Type* ReturnType - ): Type(TypeKind::Arrow), - ParamType(ParamType), - ReturnType(ReturnType) {} - - static Type* build(std::vector ParamTypes, Type* ReturnType) { + static Type* buildArrow(std::vector ParamTypes, Type* ReturnType) { Type* Curr = ReturnType; for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) { - Curr = new TArrow(*Iter, Curr); + Curr = new Type(TArrow(*Iter, Curr)); } return Curr; } - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Arrow; - } - - }; - - class TTuple : public Type { - public: - - std::vector ElementTypes; - - inline TTuple(std::vector ElementTypes): - Type(TypeKind::Tuple), ElementTypes(ElementTypes) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Tuple; - } - - }; - - class TTupleIndex : public Type { - public: - - Type* Ty; - std::size_t I; - - inline TTupleIndex(Type* Ty, std::size_t I): - Type(TypeKind::TupleIndex), Ty(Ty), I(I) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::TupleIndex; - } - - }; - - class TNil : public Type { - public: - - inline TNil(): - Type(TypeKind::Nil) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Nil; - } - - }; - - class TField : public Type { - public: - - ByteString Name; - Type* Ty; - Type* RestTy; - - inline TField( - ByteString Name, - Type* Ty, - Type* RestTy - ): Type(TypeKind::Field), - Name(Name), - Ty(Ty), - RestTy(RestTy) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Field; - } - - }; - - class TAbsent : public Type { - public: - - inline TAbsent(): - Type(TypeKind::Absent) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Absent; - } - - }; - - class TPresent : public Type { - public: - - Type* Ty; - - inline TPresent(Type* Ty): - Type(TypeKind::Present), Ty(Ty) {} - - static bool classof(const Type* Ty) { - return Ty->getKind() == TypeKind::Present; - } - }; template @@ -426,48 +705,49 @@ namespace bolt { virtual void enterType(C* Ty) {} virtual void exitType(C* Ty) {} - virtual void visitType(C* Ty) { - visitEachChild(Ty); + // virtual void visitType(C* Ty) { + // visitEachChild(Ty); + // } + + virtual void visitVarType(C& Ty) { } - virtual void visitVarType(C* Ty) { - visitType(Ty); + virtual void visitAppType(C& Ty) { + visit(Ty.Op); + visit(Ty.Arg); } - virtual void visitAppType(C* Ty) { - visitType(Ty); + virtual void visitPresentType(C& Ty) { + visit(Ty.Ty); } - virtual void visitPresentType(C* Ty) { - visitType(Ty); + virtual void visitConType(C& Ty) { } - virtual void visitConType(C* Ty) { - visitType(Ty); + virtual void visitArrowType(C& Ty) { + visit(Ty.ParamType); + visit(Ty.ReturnType); } - virtual void visitArrowType(C* Ty) { - visitType(Ty); + virtual void visitTupleType(C& Ty) { + for (auto ElTy: Ty.ElementTypes) { + visit(ElTy); + } } - virtual void visitTupleType(C* Ty) { - visitType(Ty); + virtual void visitTupleIndexType(C& Ty) { + visit(Ty.Ty); } - virtual void visitTupleIndexType(C* Ty) { - visitType(Ty); + virtual void visitAbsentType(C& Ty) { } - virtual void visitAbsentType(C* Ty) { - visitType(Ty); + virtual void visitFieldType(C& Ty) { + visit(Ty.Ty); + visit(Ty.RestTy); } - virtual void visitFieldType(C* Ty) { - visitType(Ty); - } - - virtual void visitNilType(C* Ty) { - visitType(Ty); + virtual void visitNilType(C& Ty) { } public: @@ -481,14 +761,14 @@ namespace bolt { break; case TypeKind::Arrow: { - auto Arrow = static_cast*>(Ty); + auto& Arrow = Ty->asArrow(); visit(Arrow->ParamType); visit(Arrow->ReturnType); break; } case TypeKind::Tuple: { - auto Tuple = static_cast*>(Ty); + auto& Tuple = Ty->asTuple(); for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) { visit(Tuple->ElementTypes[I]); } @@ -496,27 +776,27 @@ namespace bolt { } case TypeKind::App: { - auto App = static_cast*>(Ty); + auto& App = Ty->asApp(); visit(App->Op); visit(App->Arg); break; } case TypeKind::Field: { - auto Field = static_cast*>(Ty); + auto& Field = Ty->asField(); visit(Field->Ty); visit(Field->RestTy); break; } case TypeKind::Present: { - auto Present = static_cast*>(Ty); + auto& Present = Ty->asPresent(); visit(Present->Ty); break; } case TypeKind::TupleIndex: { - auto Index = static_cast*>(Ty); + auto& Index = Ty->asTupleIndex(); visit(Index->Ty); break; } @@ -524,37 +804,41 @@ namespace bolt { } void visit(C* Ty) { + + // Always look at the most solved solution + Ty = Ty->find(); + enterType(Ty); switch (Ty->getKind()) { case TypeKind::Present: - visitPresentType(static_cast*>(Ty)); + visitPresentType(Ty->asPresent()); break; case TypeKind::Absent: - visitAbsentType(static_cast*>(Ty)); + visitAbsentType(Ty->asAbsent()); break; case TypeKind::Nil: - visitNilType(static_cast*>(Ty)); + visitNilType(Ty->asNil()); break; case TypeKind::Field: - visitFieldType(static_cast*>(Ty)); + visitFieldType(Ty->asField()); break; case TypeKind::Con: - visitConType(static_cast*>(Ty)); + visitConType(Ty->asCon()); break; case TypeKind::Arrow: - visitArrowType(static_cast*>(Ty)); + visitArrowType(Ty->asArrow()); break; case TypeKind::Var: - visitVarType(static_cast*>(Ty)); + visitVarType(Ty->asVar()); break; case TypeKind::Tuple: - visitTupleType(static_cast*>(Ty)); + visitTupleType(Ty->asTuple()); break; case TypeKind::App: - visitAppType(static_cast*>(Ty)); + visitAppType(Ty->asApp()); break; case TypeKind::TupleIndex: - visitTupleIndexType(static_cast*>(Ty)); + visitTupleIndexType(Ty->asTupleIndex()); break; } exitType(Ty); @@ -567,11 +851,4 @@ namespace bolt { using TypeVisitor = TypeVisitorBase; using ConstTypeVisitor = TypeVisitorBase; - // template - // struct DerefHash { - // std::size_t operator()(const T& Value) const noexcept { - // return std::hash{}(*Value); - // } - // }; - } diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index 61a2f196d..8980d7ea7 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -1,13 +1,11 @@ #include -#include #include #include -#include "bolt/Type.hpp" #include "zen/config.hpp" -#include "zen/range.hpp" +#include "bolt/Type.hpp" #include "bolt/CSTVisitor.hpp" #include "bolt/DiagnosticEngine.hpp" #include "bolt/Diagnostics.hpp" @@ -39,29 +37,30 @@ namespace bolt { Type* Checker::simplifyType(Type* Ty) { - return Ty->rewrite([&](auto Ty) { + Ty = Ty->find(); - if (Ty->getKind() == TypeKind::Var) { - Ty = static_cast(Ty)->find(); - } - - if (Ty->getKind() == TypeKind::TupleIndex) { - auto Index = static_cast(Ty); - auto MaybeTuple = simplifyType(Index->Ty); - if (MaybeTuple->getKind() == TypeKind::Tuple) { - auto Tuple = static_cast(MaybeTuple); - if (Index->I >= Tuple->ElementTypes.size()) { - DE.add(Tuple, Index->I); - } else { - Ty = simplifyType(Tuple->ElementTypes[Index->I]); - } + if (Ty->isTupleIndex()) { + auto& Index = Ty->asTupleIndex(); + auto MaybeTuple = simplifyType(Index.Ty); + if (MaybeTuple->isTuple()) { + auto& Tuple = MaybeTuple->asTuple(); + if (Index.I >= Tuple.ElementTypes.size()) { + DE.add(MaybeTuple, Index.I); + } else { + auto ElementTy = simplifyType(Tuple.ElementTypes[Index.I]); + Ty->set(ElementTy); + Ty = ElementTy; } + } else if (!MaybeTuple->isVar()) { + DE.add(MaybeTuple); } + } - return Ty; - - }, /*Recursive=*/true); + return Ty; + } + Type* Checker::solveType(Type* Ty) { + return Ty->rewrite([this](auto Ty) { return simplifyType(Ty); }, true); } Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): @@ -70,6 +69,7 @@ namespace bolt { IntType = createConType("Int"); StringType = createConType("String"); ListType = createConType("List"); + UnitType = new Type(TTuple({})); } Scheme* Checker::lookup(ByteString Name) { @@ -293,7 +293,7 @@ namespace bolt { setContext(Decl->Ctx); - std::vector Vars; + std::vector Vars; for (auto TE: Decl->TVs) { auto TV = createRigidVar(TE->Name->getCanonicalText()); Decl->Ctx->TVs->emplace(TV); @@ -312,13 +312,20 @@ namespace bolt { auto TupleMember = static_cast(Member); auto RetTy = Ty; for (auto Var: Vars) { - RetTy = new TApp(RetTy, Var); + RetTy = new Type(TApp(RetTy, Var)); } std::vector ParamTypes; for (auto Element: TupleMember->Elements) { ParamTypes.push_back(inferTypeExpression(Element)); } - Decl->Ctx->Parent->add(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, TArrow::build(ParamTypes, RetTy))); + Decl->Ctx->Parent->add( + TupleMember->Name->getCanonicalText(), + new Forall( + Decl->Ctx->TVs, + Decl->Ctx->Constraints, + Type::buildArrow(ParamTypes, RetTy) + ) + ); break; } case NodeKind::RecordVariantDeclarationMember: @@ -342,7 +349,7 @@ namespace bolt { setContext(Decl->Ctx); - std::vector Vars; + std::vector Vars; for (auto TE: Decl->Vars) { auto TV = createRigidVar(TE->Name->getCanonicalText()); Vars.push_back(TV); @@ -355,15 +362,28 @@ namespace bolt { Decl->Ctx->Parent->add(Name, new Forall(Ty)); // Corresponds to the logic of one branch of a VariantDeclarationMember - Type* FieldsTy = new TNil(); + Type* FieldsTy = new Type(TNil()); for (auto Field: Decl->Fields) { - FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy); + FieldsTy = new Type( + TField( + Field->Name->getCanonicalText(), + new Type(TPresent(inferTypeExpression(Field->TypeExpression))), + FieldsTy + ) + ); } Type* RetTy = Ty; for (auto TV: Vars) { - RetTy = new TApp(RetTy, TV); + RetTy = new Type(TApp(RetTy, TV)); } - Decl->Ctx->Parent->add(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(FieldsTy, RetTy))); + Decl->Ctx->Parent->add( + Name, + new Forall( + Decl->Ctx->TVs, + Decl->Ctx->Constraints, + new Type(TArrow(FieldsTy, RetTy)) + ) + ); popContext(); break; @@ -444,11 +464,11 @@ namespace bolt { auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { auto Id = Class->Name->getCanonicalText(); auto Ctx = &getContext(); - std::vector Out; + std::vector Out; for (auto TE: Class->TypeVars) { auto Name = TE->Name->getCanonicalText(); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); - TV->Contexts.emplace(Id); + TV->asVar().Context.emplace(Id); Ctx->add(Name, new Forall(TV)); Out.push_back(TV); } @@ -586,7 +606,7 @@ namespace bolt { RetType = createTypeVar(); } - makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl); + makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl); setContext(OldCtx); } @@ -648,8 +668,8 @@ namespace bolt { if (RetStmt->Expression) { makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression); } else { - ReturnType = new TTuple({}); - makeEqual(new TTuple({}), getReturnType(), N); + ReturnType = UnitType; + makeEqual(UnitType, getReturnType(), N); } break; } @@ -691,18 +711,18 @@ namespace bolt { } - TCon* Checker::createConType(ByteString Name) { - return new TCon(NextConTypeId++, Name); + Type* Checker::createConType(ByteString Name) { + return new Type(TCon(NextConTypeId++, Name)); } - TVarRigid* Checker::createRigidVar(ByteString Name) { - auto TV = new TVarRigid(NextTypeVarId++, Name); + Type* Checker::createRigidVar(ByteString Name) { + auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}})); getContext().TVs->emplace(TV); return TV; } - TVar* Checker::createTypeVar() { - auto TV = new TVar(NextTypeVarId++, VarKind::Unification); + Type* Checker::createTypeVar() { + auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {})); getContext().TVs->emplace(TV); return TV; } @@ -727,7 +747,7 @@ namespace bolt { for (auto TV: *F->TVs) { auto Fresh = createTypeVar(); // std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl; - Fresh->Contexts = TV->Contexts; + Fresh->asVar().Context = TV->asVar().Context; Sub[TV] = Fresh; } @@ -736,8 +756,8 @@ namespace bolt { // FIXME improve this if (Constraint->getKind() == ConstraintKind::Equal) { auto Eq = static_cast(Constraint); - Eq->Left = simplifyType(Eq->Left); - Eq->Right = simplifyType(Eq->Right); + Eq->Left = solveType(Eq->Left); + Eq->Right = solveType(Eq->Right); } auto NewConstraint = Constraint->substitute(Sub); @@ -752,11 +772,11 @@ namespace bolt { addConstraint(NewConstraint); } - // Note the call to simplify? This is because constraints may have already + // This call to solve happens because constraints may have already // been solved, with some unification variables being erased. To make // sure we instantiate unification variables that are still in use // we solve before substituting. - return simplifyType(F->Type)->substitute(Sub); + return solveType(F->Type)->substitute(Sub); } } @@ -771,10 +791,8 @@ namespace bolt { std::vector Types; for (auto TE: D->TEs) { auto Ty = inferTypeExpression(TE); - ZEN_ASSERT(Ty->getKind() == TypeKind::Var && static_cast(Ty)->isRigid()); - auto TV = static_cast(Ty); - TV->Provided.emplace(D->Name->getCanonicalText()); - Types.push_back(TV); + Ty->asVar().Provided->emplace(D->Name->getCanonicalText()); + Types.push_back(Ty); } break; } @@ -813,7 +831,7 @@ namespace bolt { auto AppTE = static_cast(N); Type* Ty = inferTypeExpression(AppTE->Op, IsPoly); for (auto Arg: AppTE->Args) { - Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly)); + Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly))); } N->setType(Ty); return Ty; @@ -830,9 +848,9 @@ namespace bolt { Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); } - ZEN_ASSERT(Ty->getKind() == TypeKind::Var); + ZEN_ASSERT(Ty->isVar()); N->setType(Ty); - return static_cast(Ty); + return Ty; } case NodeKind::TupleTypeExpression: @@ -842,7 +860,7 @@ namespace bolt { for (auto [TE, Comma]: TupleTE->Elements) { ElementTypes.push_back(inferTypeExpression(TE, IsPoly)); } - auto Ty = new TTuple(ElementTypes); + auto Ty = new Type(TTuple(ElementTypes)); N->setType(Ty); return Ty; } @@ -863,7 +881,7 @@ namespace bolt { ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); } auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); - auto Ty = TArrow::build(ParamTypes, ReturnType); + auto Ty = Type::buildArrow(ParamTypes, ReturnType); N->setType(Ty); return Ty; } @@ -886,14 +904,14 @@ namespace bolt { } Type* sortRow(Type* Ty) { - std::map Fields; - while (Ty->getKind() == TypeKind::Field) { - auto Field = static_cast(Ty); - Fields.emplace(Field->Name, Field); - Ty = Field->RestTy; + std::map Fields; + while (Ty->isField()) { + auto& Field = Ty->asField(); + Fields.emplace(Field.Name, Ty); + Ty = Field.RestTy; } for (auto [Name, Field]: Fields) { - Ty = new TField(Name, Field->Ty, Ty); + Ty = new Type(TField(Name, Field->asField().Ty, Ty)); } return Ty; } @@ -930,7 +948,7 @@ namespace bolt { setContext(OldCtx); } if (!Match->Value) { - Ty = new TArrow(ValTy, Ty); + Ty = new Type(TArrow(ValTy, Ty)); } break; } @@ -938,9 +956,13 @@ namespace bolt { case NodeKind::RecordExpression: { auto Record = static_cast(X); - Ty = new TNil(); + Ty = new Type(TNil()); for (auto [Field, Comma]: Record->Fields) { - Ty = new TField(Field->Name->getCanonicalText(), new TPresent(inferExpression(Field->getExpression())), Ty); + Ty = new Type(TField( + Field->Name->getCanonicalText(), + new Type(TPresent(inferExpression(Field->getExpression()))), + Ty + )); } Ty = sortRow(Ty); break; @@ -998,7 +1020,7 @@ namespace bolt { for (auto Arg: Call->Args) { ArgTypes.push_back(inferExpression(Arg)); } - makeEqual(OpTy, TArrow::build(ArgTypes, Ty), X); + makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X); break; } @@ -1008,14 +1030,15 @@ namespace bolt { auto Scm = lookup(Infix->Operator->getText()); if (Scm == nullptr) { DE.add(Infix->Operator->getText(), Infix->Operator); - return createTypeVar(); + Ty = createTypeVar(); + break; } auto OpTy = instantiate(Scm, Infix->Operator); Ty = createTypeVar(); std::vector ArgTys; ArgTys.push_back(inferExpression(Infix->Left)); ArgTys.push_back(inferExpression(Infix->Right)); - makeEqual(TArrow::build(ArgTys, Ty), OpTy, X); + makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X); break; } @@ -1026,7 +1049,7 @@ namespace bolt { for (auto [E, Comma]: Tuple->Elements) { Types.push_back(inferExpression(E)); } - Ty = new TTuple(Types); + Ty = new Type(TTuple(Types)); break; } @@ -1038,7 +1061,7 @@ namespace bolt { case NodeKind::IntegerLiteral: { auto I = static_cast(Member->Name); - Ty = new TTupleIndex(ExprTy, I->getInteger()); + Ty = new Type(TTupleIndex(ExprTy, I->getInteger())); break; } case NodeKind::Identifier: @@ -1046,7 +1069,7 @@ namespace bolt { auto K = static_cast(Member->Name); Ty = createTypeVar(); auto RestTy = createTypeVar(); - makeEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member); + makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member); break; } default: @@ -1102,7 +1125,7 @@ namespace bolt { } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); - makeEqual(Ty, TArrow::build(ParamTypes, RetTy), P); + makeEqual(Ty, Type::buildArrow(ParamTypes, RetTy), P); return RetTy; } @@ -1113,7 +1136,7 @@ namespace bolt { for (auto [Element, Comma]: P->Elements) { ElementTypes.push_back(inferPattern(Element)); } - return new TTuple(ElementTypes); + return new Type(TTuple(ElementTypes)); } case NodeKind::ListPattern: @@ -1123,7 +1146,7 @@ namespace bolt { for (auto [Element, Separator]: P->Elements) { makeEqual(ElementType, inferPattern(Element), P); } - return new TApp(ListType, ElementType); + return new Type(TApp(ListType, ElementType)); } case NodeKind::NestedPattern: @@ -1204,7 +1227,14 @@ namespace bolt { } Type* Checker::getType(TypedNode *Node) { - return Node->getType()->solve(); + auto Ty = Node->getType(); + if (Node->Flags & NodeFlags_TypeIsSolved) { + return Ty; + } + Ty = solveType(Ty); + Node->setType(Ty); + Node->Flags |= NodeFlags_TypeIsSolved; + return Ty; } void Checker::check(SourceFile *SF) { @@ -1217,11 +1247,11 @@ namespace bolt { addBinding("True", new Forall(BoolType)); addBinding("False", new Forall(BoolType)); auto A = createTypeVar(); - addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, TArrow::build({ A, A }, BoolType))); - addBinding("+", new Forall(TArrow::build({ IntType, IntType }, IntType))); - addBinding("-", new Forall(TArrow::build({ IntType, IntType }, IntType))); - addBinding("*", new Forall(TArrow::build({ IntType, IntType }, IntType))); - addBinding("/", new Forall(TArrow::build({ IntType, IntType }, IntType))); + addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType))); + addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); + addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); + addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); + addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); populate(SF); forwardDeclare(SF); auto SCCs = RefGraph.strongconnect(); @@ -1243,6 +1273,27 @@ namespace bolt { ActiveContext = nullptr; solve(new CMany(*SF->Ctx->Constraints)); + + class Visitor : public CSTVisitor { + + Checker& C; + + public: + + Visitor(Checker& C): + C(C) {} + + void visitAnnotation(Annotation* A) { + + } + + void visitExpression(Expression* X) { + C.getType(X); + } + + } V(*this); + + V.visit(SF); } void Checker::solve(Constraint* Constraint) { @@ -1281,10 +1332,10 @@ namespace bolt { } bool assignableTo(Type* A, Type* B) { - if (isa(A) && isa(B)) { - auto Con1 = cast(A); - auto Con2 = cast(B); - if (Con1->Id != Con2-> Id) { + if (A->isCon() && B->isCon()) { + auto& Con1 = A->asCon(); + auto& Con2 = B->asCon(); + if (Con1.Id != Con2.Id) { return false; } return true; @@ -1295,13 +1346,15 @@ namespace bolt { class ArrowCursor { - std::stack> Stack; + /// Types on this stack are guaranteed to be arrow types. + std::stack> Stack; + TypePath& Path; std::size_t I; public: - ArrowCursor(TArrow* Arr, TypePath& Path): + ArrowCursor(Type* Arr, TypePath& Path): Path(Path) { Stack.push({ Arr, true }); Path.push_back(Arr->getStartIndex()); @@ -1323,9 +1376,9 @@ namespace bolt { continue; } Ty = Arrow->resolve(Index); - if (isa(Ty)) { + if (Ty->isArrow()) { auto NewIndex = Arrow->getStartIndex(); - Stack.push({ static_cast(Ty), true }); + Stack.push({ Ty, true }); Path.push_back(NewIndex); } else { return Ty; @@ -1390,40 +1443,36 @@ namespace bolt { } TypeSig getTypeSig(Type* Ty) { - struct Visitor : TypeVisitor { - Type* Op = nullptr; - std::vector Args; - void visitType(Type* Ty) override { - if (!Op) { - Op = Ty; - } else { - Args.push_back(Ty); - } - } - void visitAppType(TApp* Ty) override { - visitEachChild(Ty); + Type* Op = nullptr; + std::vector Args; + std::function Visit = [&](Type* Ty) { + if (Ty->isApp()) { + Visit(Ty->asApp().Op); + Visit(Ty->asApp().Arg); + } else if (!Op) { + Op = Ty; + } else { + Args.push_back(Ty); } }; - Visitor V; - V.visit(Ty); - return TypeSig { Ty, V.Op, V.Args }; + Visit(Ty); + return TypeSig { Ty, Op, Args }; } void propagateClasses(std::unordered_set& Classes, Type* Ty) { - if (isa(Ty)) { - auto TV = cast(Ty); + if (Ty->isVar()) { + auto TV = Ty->asVar(); for (auto Class: Classes) { - TV->Contexts.emplace(Class); + TV.Context.emplace(Class); } - if (TV->isRigid()) { - auto RV = static_cast(Ty); - for (auto Id: RV->Contexts) { - if (!RV->Provided.count(Id)) { - C.DE.add(TypeclassSignature { Id, { RV } }, getSource()); + if (TV.isRigid()) { + for (auto Id: TV.Context) { + if (!TV.Provided->count(Id)) { + C.DE.add(TypeclassSignature { Id, { Ty } }, getSource()); } } } - } else if (isa(Ty) || isa(Ty)) { + } else if (Ty->isCon() || Ty->isApp()) { auto Sig = getTypeSig(Ty); for (auto Class: Classes) { propagateClassTycon(Class, Sig); @@ -1450,13 +1499,13 @@ namespace bolt { * * Other side effects may occur. */ - void join(TVar* TV, Type* Ty) { + void join(Type* TV, Type* Ty) { // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; TV->set(Ty); - propagateClasses(TV->Contexts, Ty); + propagateClasses(TV->asVar().Context, Ty); // This is a very specific adjustment that is critical to the // well-functioning of the infer/unify algorithm. When addConstraint() is @@ -1480,21 +1529,21 @@ namespace bolt { }; bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { - if (isa(A) && isa(B)) { + if (A->isAbsent() && B->isAbsent()) { return true; } - if (isa(B)) { + if (B->isAbsent()) { std::swap(A, B); DidSwap = !DidSwap; } - if (isa(A)) { - auto Present = static_cast(B); - C.DE.add(CurrentFieldName, C.simplifyType(getLeft()), LeftPath, getSource()); + if (A->isAbsent()) { + auto& Present = B->asPresent(); + C.DE.add(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource()); return false; } - auto Present1 = static_cast(A); - auto Present2 = static_cast(B); - return unify(Present1->Ty, Present2->Ty, DidSwap); + auto& Present1 = A->asPresent(); + auto& Present2 = B->asPresent(); + return unify(Present1.Ty, Present2.Ty, DidSwap); }; bool Unifier::unify(Type* A, Type* B, bool DidSwap) { @@ -1504,8 +1553,8 @@ namespace bolt { auto unifyError = [&]() { C.DE.add( - C.simplifyType(Constraint->Left), - C.simplifyType(Constraint->Right), + Constraint->Left, + Constraint->Right, LeftPath, RightPath, Constraint->Source @@ -1549,50 +1598,50 @@ namespace bolt { DidSwap = !DidSwap; }; - if (isa(A) && isa(B)) { - auto Var1 = static_cast(A); - auto Var2 = static_cast(B); - if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { - if (Var1->Id != Var2->Id) { + if (A->isVar() && B->isVar()) { + auto& Var1 = A->asVar(); + auto& Var2 = B->asVar(); + if (Var1.isRigid() && Var2.isRigid()) { + if (Var1.Id != Var2.Id) { unifyError(); return false; } return true; } - TVar* To; - TVar* From; - if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { - To = Var1; - From = Var2; + Type* To; + Type* From; + if (Var1.isRigid() && Var2.isUni()) { + To = A; + From = B; } else { // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification // Either way, Var1, being Unification, is a good candidate for being unified away - To = Var2; - From = Var1; + To = B; + From = A; } - if (From->Id != To->Id) { + if (From->asVar().Id != To->asVar().Id) { join(From, To); } return true; } - if (isa(B)) { + if (B->isVar()) { swap(); } - if (isa(A)) { + if (A->isVar()) { - auto TV = static_cast(A); + auto& TV = A->asVar(); // Rigid type variables can never unify with antything else than what we // have already handled in the previous if-statement, so issue an error. - if (TV->getVarKind() == VarKind::Rigid) { + if (TV.isRigid()) { unifyError(); return false; } // Occurs check - if (B->hasTypeVar(TV)) { + if (B->hasTypeVar(A)) { // NOTE Just like GHC, we just display an error message indicating that // A cannot match B, e.g. a cannot match [a]. It looks much better // than obsure references to an occurs check @@ -1600,25 +1649,25 @@ namespace bolt { return false; } - join(TV, B); + join(A, B); return true; } - if (isa(A) && isa(B)) { - auto Arrow1 = static_cast(A); - auto Arrow2 = static_cast(B); + if (A->isArrow() && B->isArrow()) { + auto& Arrow1 = A->asArrow(); + auto& Arrow2 = B->asArrow(); bool Success = true; LeftPath.push_back(TypeIndex::forArrowParamType()); RightPath.push_back(TypeIndex::forArrowParamType()); - if (!unify(Arrow1->ParamType, Arrow2->ParamType, DidSwap)) { + if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forArrowReturnType()); RightPath.push_back(TypeIndex::forArrowReturnType()); - if (!unify(Arrow1->ReturnType, Arrow2->ReturnType, DidSwap)) { + if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) { Success = false; } LeftPath.pop_back(); @@ -1626,20 +1675,20 @@ namespace bolt { return Success; } - if (isa(A) && isa(B)) { - auto App1 = static_cast(A); - auto App2 = static_cast(B); + if (A->isApp() && B->isApp()) { + auto& App1 = A->asApp(); + auto& App2 = B->asApp(); bool Success = true; LeftPath.push_back(TypeIndex::forAppOpType()); RightPath.push_back(TypeIndex::forAppOpType()); - if (!unify(App1->Op, App2->Op, DidSwap)) { + if (!unify(App1.Op, App2.Op, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forAppArgType()); RightPath.push_back(TypeIndex::forAppArgType()); - if (!unify(App1->Arg, App2->Arg, DidSwap)) { + if (!unify(App1.Arg, App2.Arg, DidSwap)) { Success = false; } LeftPath.pop_back(); @@ -1647,19 +1696,19 @@ namespace bolt { return Success; } - if (isa(A) && isa(B)) { - auto Tuple1 = static_cast(A); - auto Tuple2 = static_cast(B); - if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { + if (A->isTuple() && B->isTuple()) { + auto& Tuple1 = A->asTuple(); + auto& Tuple2 = B->asTuple(); + if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) { unifyError(); return false; } - auto Count = Tuple1->ElementTypes.size(); + auto Count = Tuple1.ElementTypes.size(); bool Success = true; for (size_t I = 0; I < Count; I++) { LeftPath.push_back(TypeIndex::forTupleElement(I)); RightPath.push_back(TypeIndex::forTupleElement(I)); - if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], DidSwap)) { + if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) { Success = false; } LeftPath.pop_back(); @@ -1668,84 +1717,85 @@ namespace bolt { return Success; } - if (isa(A) || isa(B)) { + if (A->isTupleIndex() || B->isTupleIndex()) { // Type(s) could not be simplified at the beginning of this function, // so we have to re-visit the constraint when there is more information. C.Queue.push_back(Constraint); return true; } - // if (isa(A) && isa(B)) { + // This does not work because it ignores the indices + // if (A->isTupleIndex() && B->isTupleIndex()) { // auto Index1 = static_cast(A); // auto Index2 = static_cast(B); // return unify(Index1->Ty, Index2->Ty, Source); // } - if (isa(A) && isa(B)) { - auto Con1 = static_cast(A); - auto Con2 = static_cast(B); - if (Con1->Id != Con2->Id) { + if (A->isCon() && B->isCon()) { + auto& Con1 = A->asCon(); + auto& Con2 = B->asCon(); + if (Con1.Id != Con2.Id) { unifyError(); return false; } return true; } - if (isa(A) && isa(B)) { + if (A->isNil() && B->isNil()) { return true; } - if (isa(A) && isa(B)) { - auto Field1 = static_cast(A); - auto Field2 = static_cast(B); + if (A->isField() && B->isField()) { + auto& Field1 = A->asField(); + auto& Field2 = B->asField(); bool Success = true; - if (Field1->Name == Field2->Name) { + if (Field1.Name == Field2.Name) { LeftPath.push_back(TypeIndex::forFieldType()); RightPath.push_back(TypeIndex::forFieldType()); - CurrentFieldName = Field1->Name; - if (!unifyField(Field1->Ty, Field2->Ty, DidSwap)) { + CurrentFieldName = Field1.Name; + if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); LeftPath.push_back(TypeIndex::forFieldRest()); RightPath.push_back(TypeIndex::forFieldRest()); - if (!unify(Field1->RestTy, Field2->RestTy, DidSwap)) { + if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) { Success = false; } LeftPath.pop_back(); RightPath.pop_back(); return Success; } - auto NewRestTy = new TVar(C.NextTypeVarId++, VarKind::Unification); + auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++)); pushLeft(TypeIndex::forFieldRest()); - if (!unify(Field1->RestTy, new TField(Field2->Name, Field2->Ty, NewRestTy), DidSwap)) { + if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) { Success = false; } popLeft(); pushRight(TypeIndex::forFieldRest()); - if (!unify(new TField(Field1->Name, Field1->Ty, NewRestTy), Field2->RestTy, DidSwap)) { + if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) { Success = false; } popRight(); return Success; } - if (isa(A) && isa(B)) { + if (A->isNil() && B->isField()) { swap(); } - if (isa(A) && isa(B)) { - auto Field = static_cast(A); + if (A->isField() && B->isNil()) { + auto& Field = A->asField(); bool Success = true; pushLeft(TypeIndex::forFieldType()); - CurrentFieldName = Field->Name; - if (!unifyField(Field->Ty, new TAbsent, DidSwap)) { + CurrentFieldName = Field.Name; + if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) { Success = false; } popLeft(); pushLeft(TypeIndex::forFieldRest()); - if (!unify(Field->RestTy, B, DidSwap)) { + if (!unify(Field.RestTy, B, DidSwap)) { Success = false; } popLeft(); @@ -1762,6 +1812,5 @@ namespace bolt { A.unify(); } - } diff --git a/bootstrap/cxx/src/ConsolePrinter.cc b/bootstrap/cxx/src/ConsolePrinter.cc index d11aac455..8b8f98a90 100644 --- a/bootstrap/cxx/src/ConsolePrinter.cc +++ b/bootstrap/cxx/src/ConsolePrinter.cc @@ -193,41 +193,42 @@ namespace bolt { } std::string describe(const Type* Ty) { + Ty = Ty->find(); switch (Ty->getKind()) { case TypeKind::Var: { - auto TV = static_cast(Ty); - if (TV->getVarKind() == VarKind::Rigid) { - return static_cast(TV)->Name; + auto TV = Ty->asVar(); + if (TV.isRigid()) { + return *TV.Name; } - return "a" + std::to_string(TV->Id); + return "a" + std::to_string(TV.Id); } case TypeKind::Arrow: { - auto Y = static_cast(Ty); + auto Y = Ty->asArrow(); std::ostringstream Out; - Out << describe(Y->ParamType) << " -> " << describe(Y->ReturnType); + Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType); return Out.str(); } case TypeKind::Con: { - auto Y = static_cast(Ty); - return Y->DisplayName; + auto Y = Ty->asCon(); + return Y.DisplayName; } case TypeKind::App: { - auto Y = static_cast(Ty); - return describe(Y->Op) + " " + describe(Y->Arg); + auto Y = Ty->asApp(); + return describe(Y.Op) + " " + describe(Y.Arg); } case TypeKind::Tuple: { std::ostringstream Out; - auto Y = static_cast(Ty); + auto Y = Ty->asTuple(); Out << "("; - if (Y->ElementTypes.size()) { - auto Iter = Y->ElementTypes.begin(); + if (Y.ElementTypes.size()) { + auto Iter = Y.ElementTypes.begin(); Out << describe(*Iter++); - while (Iter != Y->ElementTypes.end()) { + while (Iter != Y.ElementTypes.end()) { Out << ", " << describe(*Iter++); } } @@ -236,8 +237,8 @@ namespace bolt { } case TypeKind::TupleIndex: { - auto Y = static_cast(Ty); - return describe(Y->Ty) + "." + std::to_string(Y->I); + auto Y = Ty->asTupleIndex(); + return describe(Y.Ty) + "." + std::to_string(Y.I); } case TypeKind::Nil: return "{}"; @@ -245,19 +246,19 @@ namespace bolt { return "Abs"; case TypeKind::Present: { - auto Y = static_cast(Ty); - return describe(Y->Ty); + auto Y = Ty->asPresent(); + return describe(Y.Ty); } case TypeKind::Field: { - auto Y = static_cast(Ty); + auto Y = Ty->asField(); std::ostringstream out; - out << "{ " << Y->Name << ": " << describe(Y->Ty); - Ty = Y->RestTy; + out << "{ " << Y.Name << ": " << describe(Y.Ty); + Ty = Y.RestTy; while (Ty->getKind() == TypeKind::Field) { - auto Y = static_cast(Ty); - out << "; " + Y->Name + ": " + describe(Y->Ty); - Ty = Y->RestTy; + auto Y = Ty->asField(); + out << "; " + Y.Name + ": " + describe(Y.Ty); + Ty = Y.RestTy; } if (Ty->getKind() != TypeKind::Nil) { out << "; " + describe(Ty); @@ -561,53 +562,52 @@ namespace bolt { void exitType(const Type* Ty) override { if (shouldUnderline()) { - W.setUnderline(false); + W.setUnderline(false); // FIXME Should set to old value } } - void visitAppType(const TApp *Ty) override { - auto Y = static_cast(Ty); + void visitAppType(const TApp& Ty) override { Path.push_back(TypeIndex::forAppOpType()); - visit(Y->Op); + visit(Ty.Op); Path.pop_back(); W.write(" "); Path.push_back(TypeIndex::forAppArgType()); - visit(Y->Arg); + visit(Ty.Arg); Path.pop_back(); } - void visitVarType(const TVar* Ty) override { - if (Ty->getVarKind() == VarKind::Rigid) { - W.write(static_cast(Ty)->Name); + void visitVarType(const TVar& Ty) override { + if (Ty.isRigid()) { + W.write(*Ty.Name); return; } W.write("a"); - W.write(Ty->Id); + W.write(Ty.Id); } - void visitConType(const TCon *Ty) override { - W.write(Ty->DisplayName); + void visitConType(const TCon& Ty) override { + W.write(Ty.DisplayName); } - void visitArrowType(const TArrow* Ty) override { + void visitArrowType(const TArrow& Ty) override { Path.push_back(TypeIndex::forArrowParamType()); - visit(Ty->ParamType); + visit(Ty.ParamType); Path.pop_back(); W.write(" -> "); Path.push_back(TypeIndex::forArrowReturnType()); - visit(Ty->ReturnType); + visit(Ty.ReturnType); Path.pop_back(); } - void visitTupleType(const TTuple *Ty) override { + void visitTupleType(const TTuple& Ty) override { W.write("("); - if (Ty->ElementTypes.size()) { - auto Iter = Ty->ElementTypes.begin(); + if (Ty.ElementTypes.size()) { + auto Iter = Ty.ElementTypes.begin(); Path.push_back(TypeIndex::forTupleElement(0)); visit(*Iter++); Path.pop_back(); std::size_t I = 1; - while (Iter != Ty->ElementTypes.end()) { + while (Iter != Ty.ElementTypes.end()) { W.write(", "); Path.push_back(TypeIndex::forTupleElement(I++)); visit(*Iter++); @@ -617,47 +617,47 @@ namespace bolt { W.write(")"); } - void visitTupleIndexType(const TTupleIndex *Ty) override { + void visitTupleIndexType(const TTupleIndex& Ty) override { Path.push_back(TypeIndex::forTupleIndexType()); - visit(Ty->Ty); + visit(Ty.Ty); Path.pop_back(); W.write("."); - W.write(Ty->I); + W.write(Ty.I); } - void visitNilType(const TNil *Ty) override { + void visitNilType(const TNil& Ty) override { W.write("{}"); } - void visitAbsentType(const TAbsent *Ty) override { + void visitAbsentType(const TAbsent& Ty) override { W.write("Abs"); } - void visitPresentType(const TPresent *Ty) override { + void visitPresentType(const TPresent& Ty) override { Path.push_back(TypeIndex::forPresentType()); - visit(Ty->Ty); + visit(Ty.Ty); Path.pop_back(); } - void visitFieldType(const TField* Ty) override { + void visitFieldType(const TField& Ty) override { W.write("{ "); - W.write(Ty->Name); + W.write(Ty.Name); W.write(": "); Path.push_back(TypeIndex::forFieldType()); - visit(Ty->Ty); + visit(Ty.Ty); Path.pop_back(); - auto Ty2 = Ty->RestTy; + auto Ty2 = Ty.RestTy; Path.push_back(TypeIndex::forFieldRest()); std::size_t I = 1; - while (Ty2->getKind() == TypeKind::Field) { - auto Y = static_cast(Ty2); + while (Ty2->isField()) { + auto Y = Ty2->asField(); W.write("; "); - W.write(Y->Name); + W.write(Y.Name); W.write(": "); Path.push_back(TypeIndex::forFieldType()); - visit(Y->Ty); + visit(Y.Ty); Path.pop_back(); - Ty2 = Y->RestTy; + Ty2 = Y.RestTy; Path.push_back(TypeIndex::forFieldRest()); ++I; } @@ -730,7 +730,7 @@ namespace bolt { case DiagnosticKind::BindingNotFound: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); write("binding "); writeBinding(E.Name); @@ -746,7 +746,7 @@ namespace bolt { case DiagnosticKind::UnexpectedToken: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); writeLoc(E.File, E.Actual->getStartLoc()); write(" expected "); @@ -780,7 +780,7 @@ namespace bolt { case DiagnosticKind::UnexpectedString: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); writeLoc(E.File, E.Location); write(" unexpected '"); @@ -806,7 +806,7 @@ namespace bolt { case DiagnosticKind::UnificationError: { - auto E = static_cast(D); + auto& E = static_cast(D); auto Left = E.OrigLeft->resolve(E.LeftPath); auto Right = E.OrigRight->resolve(E.RightPath); writePrefix(E); @@ -857,7 +857,7 @@ namespace bolt { case DiagnosticKind::TypeclassMissing: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); write("the type class "); writeTypeclassSignature(E.Sig); @@ -869,7 +869,7 @@ namespace bolt { case DiagnosticKind::InstanceNotFound: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); write("a type class instance "); writeTypeclassName(E.TypeclassName); @@ -883,7 +883,7 @@ namespace bolt { case DiagnosticKind::TupleIndexOutOfRange: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); write("the index "); writeType(E.I); @@ -894,7 +894,7 @@ namespace bolt { case DiagnosticKind::InvalidTypeToTypeclass: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); write("the type "); writeType(E.Actual); @@ -911,7 +911,7 @@ namespace bolt { case DiagnosticKind::FieldNotFound: { - auto E = static_cast(D); + auto& E = static_cast(D); writePrefix(E); write("the field '"); write(E.Name); @@ -921,6 +921,16 @@ namespace bolt { break; } + case DiagnosticKind::NotATuple: + { + auto& E = static_cast(D); + writePrefix(E); + write("the type "); + writeType(E.Ty); + write(" is not a tuple.\n"); + break; + } + } } diff --git a/bootstrap/cxx/src/Types.cc b/bootstrap/cxx/src/Types.cc index feffac871..60dce6f7e 100644 --- a/bootstrap/cxx/src/Types.cc +++ b/bootstrap/cxx/src/Types.cc @@ -1,9 +1,8 @@ -#include "zen/config.hpp" -#include "zen/range.hpp" - -#include "bolt/Common.hpp" #include "bolt/Type.hpp" +#include +#include +#include namespace bolt { @@ -13,13 +12,13 @@ namespace bolt { } ZEN_ASSERT(Params.size() == 1); ZEN_ASSERT(Other.Params.size() == 1); - return Params[0]->Id < Other.Params[0]->Id; + return Params[0]->asCon().Id < Other.Params[0]->asCon().Id; } bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { ZEN_ASSERT(Params.size() == 1); ZEN_ASSERT(Other.Params.size() == 1); - return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id; + return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id; } bool TypeIndex::operator==(const TypeIndex& Other) const noexcept { @@ -35,36 +34,122 @@ namespace bolt { } } - void TypeIndex::advance(const Type* Ty) { + bool TCon::operator==(const TCon& Other) const { + return Id == Other.Id; + } + + bool TApp::operator==(const TApp& Other) const { + return *Op == *Other.Op && *Arg == *Other.Arg; + } + + bool TVar::operator==(const TVar& Other) const { + return Id == Other.Id; + } + + bool TArrow::operator==(const TArrow& Other) const { + return *ParamType == *Other.ParamType + && *ReturnType == *Other.ReturnType; + } + + bool TTuple::operator==(const TTuple& Other) const { + for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) { + if (*T1 != *T2) { + return false; + } + } + return true; + } + + bool TTupleIndex::operator==(const TTupleIndex& Other) const { + return *Ty == *Other.Ty && I == Other.I; + } + + bool TNil::operator==(const TNil& Other) const { + return true; + } + + bool TField::operator==(const TField& Other) const { + return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy; + } + + bool TAbsent::operator==(const TAbsent& Other) const { + return true; + } + + bool TPresent::operator==(const TPresent& Other) const { + return *Ty == *Other.Ty; + } + + bool Type::operator==(const Type& Other) const { + if (Kind != Other.Kind) { + return false; + } switch (Kind) { - case TypeIndexKind::End: + case TypeKind::Var: + return Var == Other.Var; + case TypeKind::Con: + return Con == Other.Con; + case TypeKind::Present: + return Present == Other.Present; + case TypeKind::Absent: + return Absent == Other.Absent; + case TypeKind::Arrow: + return Arrow == Other.Arrow; + case TypeKind::Field: + return Field == Other.Field; + case TypeKind::Nil: + return Nil == Other.Nil; + case TypeKind::Tuple: + return Tuple == Other.Tuple; + case TypeKind::TupleIndex: + return TupleIndex == Other.TupleIndex; + case TypeKind::App: + return App == Other.App; + } + } + + void Type::visitEachChild(std::function Proc) { + switch (Kind) { + case TypeKind::Var: + case TypeKind::Absent: + case TypeKind::Nil: + case TypeKind::Con: break; - case TypeIndexKind::AppOpType: - Kind = TypeIndexKind::AppArgType; - break; - case TypeIndexKind::ArrowParamType: - Kind = TypeIndexKind::ArrowReturnType; - break; - case TypeIndexKind::ArrowReturnType: - Kind = TypeIndexKind::End; - break; - case TypeIndexKind::FieldType: - Kind = TypeIndexKind::FieldRestType; - break; - case TypeIndexKind::FieldRestType: - case TypeIndexKind::TupleIndexType: - case TypeIndexKind::PresentType: - case TypeIndexKind::AppArgType: - case TypeIndexKind::TupleElement: + case TypeKind::Arrow: { - auto Tuple = cast(Ty); - if (I+1 < Tuple->ElementTypes.size()) { - ++I; - } else { - Kind = TypeIndexKind::End; + Proc(Arrow.ParamType); + Proc(Arrow.ReturnType); + break; + } + case TypeKind::Tuple: + { + for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) { + Proc(Tuple.ElementTypes[I]); } break; } + case TypeKind::App: + { + Proc(App.Op); + Proc(App.Arg); + break; + } + case TypeKind::Field: + { + Proc(Field.Ty); + Proc(Field.RestTy); + break; + } + case TypeKind::Present: + { + Proc(Present.Ty); + break; + } + case TypeKind::TupleIndex: + { + Proc(TupleIndex.Ty); + break; + } } } @@ -81,49 +166,49 @@ namespace bolt { return Ty2; case TypeKind::Arrow: { - auto Arrow = static_cast(Ty2); + auto Arrow = Ty2->asArrow(); bool Changed = false; - Type* NewParamType = Arrow->ParamType->rewrite(Fn); - if (NewParamType != Arrow->ParamType) { + Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive); + if (NewParamType != Arrow.ParamType) { Changed = true; } - auto NewRetTy = Arrow->ReturnType->rewrite(Fn); - if (NewRetTy != Arrow->ReturnType) { + auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive); + if (NewRetTy != Arrow.ReturnType) { Changed = true; } - return Changed ? new TArrow(NewParamType, NewRetTy) : Ty2; + return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2; } case TypeKind::Con: return Ty2; case TypeKind::App: { - auto App = static_cast(Ty2); - auto NewOp = App->Op->rewrite(Fn); - auto NewArg = App->Arg->rewrite(Fn); - if (NewOp == App->Op && NewArg == App->Arg) { - return App; + auto App = Ty2->asApp(); + auto NewOp = App.Op->rewrite(Fn, Recursive); + auto NewArg = App.Arg->rewrite(Fn, Recursive); + if (NewOp == App.Op && NewArg == App.Arg) { + return Ty2; } - return new TApp(NewOp, NewArg); + return new Type(TApp(NewOp, NewArg)); } case TypeKind::TupleIndex: { - auto Tuple = static_cast(Ty2); - auto NewTy = Tuple->Ty->rewrite(Fn); - return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple; + auto Tuple = Ty2->asTupleIndex(); + auto NewTy = Tuple.Ty->rewrite(Fn, Recursive); + return NewTy != Tuple.Ty ? new Type(TTupleIndex(NewTy, Tuple.I)) : Ty2; } case TypeKind::Tuple: { - auto Tuple = static_cast(Ty2); + auto Tuple = Ty2->asTuple(); bool Changed = false; std::vector NewElementTypes; - for (auto Ty: Tuple->ElementTypes) { - auto NewElementType = Ty->rewrite(Fn); + for (auto Ty: Tuple.ElementTypes) { + auto NewElementType = Ty->rewrite(Fn, Recursive); if (NewElementType != Ty) { Changed = true; } NewElementTypes.push_back(NewElementType); } - return Changed ? new TTuple(NewElementTypes) : Ty2; + return Changed ? new Type(TTuple(NewElementTypes)) : Ty2; } case TypeKind::Nil: return Ty2; @@ -131,272 +216,77 @@ namespace bolt { return Ty2; case TypeKind::Field: { - auto Field = static_cast(Ty2); + auto Field = Ty2->asField(); bool Changed = false; - auto NewTy = Field->Ty->rewrite(Fn); - if (NewTy != Field->Ty) { + auto NewTy = Field.Ty->rewrite(Fn, Recursive); + if (NewTy != Field.Ty) { Changed = true; } - auto NewRestTy = Field->RestTy->rewrite(Fn); - if (NewRestTy != Field->RestTy) { + auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive); + if (NewRestTy != Field.RestTy) { Changed = true; } - return Changed ? new TField(Field->Name, NewTy, NewRestTy) : Ty2; + return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2; } case TypeKind::Present: { - auto Present = static_cast(Ty2); - auto NewTy = Present->Ty->rewrite(Fn); - if (NewTy == Present->Ty) { + auto Present = Ty2->asPresent(); + auto NewTy = Present.Ty->rewrite(Fn, Recursive); + if (NewTy == Present.Ty) { return Ty2; } - return new TPresent(NewTy); + return new Type(TPresent(NewTy)); } } - - } - - void Type::addTypeVars(TVSet& TVs) { - switch (Kind) { - case TypeKind::Var: - TVs.emplace(static_cast(this)); - break; - case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - Arrow->ParamType->addTypeVars(TVs); - Arrow->ReturnType->addTypeVars(TVs); - break; - } - case TypeKind::Con: - break; - case TypeKind::App: - { - auto App = static_cast(this); - App->Op->addTypeVars(TVs); - App->Arg->addTypeVars(TVs); - break; - } - case TypeKind::TupleIndex: - { - auto Index = static_cast(this); - Index->Ty->addTypeVars(TVs); - break; - } - case TypeKind::Tuple: - { - auto Tuple = static_cast(this); - for (auto Ty: Tuple->ElementTypes) { - Ty->addTypeVars(TVs); - } - break; - } - case TypeKind::Nil: - break; - case TypeKind::Field: - { - auto Field = static_cast(this); - Field->Ty->addTypeVars(TVs); - Field->Ty->addTypeVars(TVs); - break; - } - case TypeKind::Present: - { - auto Present = static_cast(this); - Present->Ty->addTypeVars(TVs); - break; - } - case TypeKind::Absent: - break; - } - } - - bool Type::hasTypeVar(const TVar* TV) { - switch (Kind) { - case TypeKind::Var: - return static_cast(this)->Id == TV->Id; - case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - return Arrow->ParamType->hasTypeVar(TV) || Arrow->ReturnType->hasTypeVar(TV); - } - case TypeKind::Con: - return false; - case TypeKind::App: - { - auto App = static_cast(this); - return App->Op->hasTypeVar(TV) || App->Arg->hasTypeVar(TV); - } - case TypeKind::TupleIndex: - { - auto Index = static_cast(this); - return Index->Ty->hasTypeVar(TV); - } - case TypeKind::Tuple: - { - auto Tuple = static_cast(this); - for (auto Ty: Tuple->ElementTypes) { - if (Ty->hasTypeVar(TV)) { - return true; - } - } - return false; - } - case TypeKind::Nil: - return false; - case TypeKind::Field: - { - auto Field = static_cast(this); - return Field->Ty->hasTypeVar(TV) || Field->RestTy->hasTypeVar(TV); - } - case TypeKind::Present: - { - auto Present = static_cast(this); - return Present->Ty->hasTypeVar(TV); - } - case TypeKind::Absent: - return false; - } - } - - Type* Type::solve() { - return rewrite([](auto Ty) { - if (Ty->getKind() == TypeKind::Var) { - return static_cast(Ty)->find(); - } - return Ty; - }); } Type* Type::substitute(const TVSub &Sub) { return rewrite([&](auto Ty) { - if (isa(Ty)) { - auto TV = static_cast(Ty); - auto Match = Sub.find(TV); + if (Ty->isVar()) { + auto Match = Sub.find(Ty); return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; } return Ty; - }); + }, false); } Type* Type::resolve(const TypeIndex& Index) const noexcept { switch (Index.Kind) { case TypeIndexKind::PresentType: - return cast(this)->Ty; + return this->asPresent().Ty; case TypeIndexKind::AppOpType: - return cast(this)->Op; + return this->asApp().Op; case TypeIndexKind::AppArgType: - return cast(this)->Arg; + return this->asApp().Arg; case TypeIndexKind::TupleIndexType: - return cast(this)->Ty; + return this->asTupleIndex().Ty; case TypeIndexKind::TupleElement: - return cast(this)->ElementTypes[Index.I]; + return this->asTuple().ElementTypes[Index.I]; case TypeIndexKind::ArrowParamType: - return cast(this)->ParamType; + return this->asArrow().ParamType; case TypeIndexKind::ArrowReturnType: - return cast(this)->ReturnType; + return this->asArrow().ReturnType; case TypeIndexKind::FieldType: - return cast(this)->Ty; + return this->asField().Ty; case TypeIndexKind::FieldRestType: - return cast(this)->RestTy; + return this->asField().RestTy; case TypeIndexKind::End: ZEN_UNREACHABLE } ZEN_UNREACHABLE } - bool Type::operator==(const Type& Other) const noexcept { - switch (Kind) { - case TypeKind::Var: - if (Other.Kind != TypeKind::Var) { - return false; - } - return static_cast(this)->Id == static_cast(Other).Id; - case TypeKind::Tuple: - { - if (Other.Kind != TypeKind::Tuple) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - if (A.ElementTypes.size() != B.ElementTypes.size()) { - return false; - } - for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) { - if (*T1 != *T2) { - return false; - } - } - return true; + TVSet Type::getTypeVars() { + TVSet Out; + std::function visit = [&](Type* Ty) { + if (Ty->isVar()) { + Out.emplace(Ty); + return; } - case TypeKind::TupleIndex: - { - if (Other.Kind != TypeKind::TupleIndex) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - return A.I == B.I && *A.Ty == *B.Ty; - } - case TypeKind::Con: - { - if (Other.Kind != TypeKind::Con) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - if (A.Id != B.Id) { - return false; - } - return true; - } - case TypeKind::App: - { - if (Other.Kind != TypeKind::App) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - return *A.Op == *B.Op && *A.Arg == *B.Arg; - } - case TypeKind::Arrow: - { - if (Other.Kind != TypeKind::Arrow) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - return *A.ParamType == *B.ParamType && *A.ReturnType == *B.ReturnType; - } - case TypeKind::Absent: - if (Other.Kind != TypeKind::Absent) { - return false; - } - return true; - case TypeKind::Nil: - if (Other.Kind != TypeKind::Nil) { - return false; - } - return true; - case TypeKind::Present: - { - if (Other.Kind != TypeKind::Present) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - return *A.Ty == *B.Ty; - } - case TypeKind::Field: - { - if (Other.Kind != TypeKind::Field) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - return A.Name == B.Name && *A.Ty == *B.Ty && *A.RestTy == *B.RestTy; - } - } + Ty->visitEachChild(visit); + }; + visit(this); + return Out; } TypeIterator Type::begin() { @@ -407,14 +297,13 @@ namespace bolt { return TypeIterator { this, getEndIndex() }; } - TypeIndex Type::getStartIndex() { + TypeIndex Type::getStartIndex() const { switch (Kind) { case TypeKind::Arrow: return TypeIndex::forArrowParamType(); case TypeKind::Tuple: { - auto Tuple = static_cast(this); - if (Tuple->ElementTypes.empty()) { + if (asTuple().ElementTypes.empty()) { return TypeIndex(TypeIndexKind::End); } return TypeIndex::forTupleElement(0); @@ -426,29 +315,38 @@ namespace bolt { } } - TypeIndex Type::getEndIndex() { + TypeIndex Type::getEndIndex() const { return TypeIndex(TypeIndexKind::End); } - - inline Type* TVar::find() { - TVar* Curr = this; - for (;;) { - auto Keep = Curr->Parent; - if (Keep->getKind() != TypeKind::Var || Keep == Curr) { - return Keep; - } - auto TV = static_cast(Keep); - Curr->Parent = TV->Parent; - Curr = TV; + bool Type::hasTypeVar(Type* TV) const { + switch (Kind) { + case TypeKind::Var: + return Var.Id == TV->asVar().Id; + case TypeKind::Con: + case TypeKind::Absent: + case TypeKind::Nil: + return false; + case TypeKind::App: + return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV); + case TypeKind::Tuple: + for (auto Ty: Tuple.ElementTypes) { + if (Ty->hasTypeVar(TV)) { + return true; + } + } + return false; + case TypeKind::TupleIndex: + return TupleIndex.Ty->hasTypeVar(TV); + case TypeKind::Field: + return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV); + case TypeKind::Arrow: + return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV); + case TypeKind::Present: + return Present.Ty->hasTypeVar(TV); } } - void TVar::set(Type* Ty) { - auto Root = find(); - // It is not possible to set a solution twice. - ZEN_ASSERT(Root->getKind() == TypeKind::Var); - static_cast(Root)->Parent = Ty; - } - } + +