diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index ae332db1f..1fb9b2793 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -1497,6 +1497,9 @@ namespace bolt { Fields(Fields), RBrace(RBrace) {} + Token* getFirstToken() const override; + Token* getLastToken() const override; + }; class Statement : public Node { @@ -1800,6 +1803,7 @@ namespace bolt { class PubKeyword* PubKeyword; class StructKeyword* StructKeyword; IdentifierAlt* Name; + std::vector Vars; class BlockStart* BlockStart; std::vector Fields; @@ -1807,12 +1811,14 @@ namespace bolt { class PubKeyword* PubKeyword, class StructKeyword* StructKeyword, IdentifierAlt* Name, + std::vector Vars, class BlockStart* BlockStart, std::vector Fields ): Node(NodeKind::RecordDeclaration), PubKeyword(PubKeyword), StructKeyword(StructKeyword), Name(Name), + Vars(Vars), BlockStart(BlockStart), Fields(Fields) {} diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index d604e64c3..d4da9385a 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -166,6 +166,9 @@ namespace bolt { class Checker { + friend class Unifier; + friend class UnificationFrame; + const LanguageConfig& Config; DiagnosticEngine& DE; @@ -178,14 +181,10 @@ namespace bolt { Graph RefGraph; - std::unordered_map CallGraph; - std::unordered_map> InstanceMap; std::vector Contexts; - TVSub Solution; - /** * The queue that is used during solving to store any unsolved constraints. 
*/ @@ -208,15 +207,14 @@ namespace bolt { Type* inferTypeExpression(TypeExpression* TE); Type* inferLiteral(Literal* Lit); - void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs); - void inferBindings(Pattern* Pattern, Type* T); + Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); void infer(Node* node); void inferLetDeclaration(LetDeclaration* N); Constraint* convertToConstraint(ConstraintExpression* C); - TCon* createPrimConType(); + TCon* createConType(ByteString Name); TVar* createTypeVar(); TVarRigid* createRigidVar(ByteString Name); InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet); @@ -239,8 +237,6 @@ namespace bolt { */ Type* lookupMono(ByteString Name); - InferContext* lookupCall(Node* Source, SymbolPath Path); - /** * Get the return type for the current context. If none could be found, the program will abort. */ @@ -252,10 +248,6 @@ namespace bolt { void propagateClasses(TypeclassContext& Classes, Type* Ty); void propagateClassTycon(TypeclassId& Class, TCon* Ty); - Type* simplify(Type* Ty); - - Type* find(Type* Ty); - /** * Assign a type to a unification variable. 
* @@ -268,18 +260,21 @@ namespace bolt { */ void join(TVar* A, Type* B); + // Unification parameters Type* OrigLeft; Type* OrigRight; TypePath LeftPath; TypePath RightPath; + ByteString CurrentFieldName; Node* Source; bool unify(Type* A, Type* B); void unifyError(); + void solveCEqual(CEqual* C); - void solve(Constraint* Constraint, TVSub& Solution); + void solve(Constraint* Constraint); void populate(SourceFile* SF); @@ -293,17 +288,22 @@ namespace bolt { Checker(const LanguageConfig& Config, DiagnosticEngine& DE); + /** + * \internal + */ + Type* simplifyType(Type* Ty); + void check(SourceFile* SF); - inline Type* getBoolType() { + inline Type* getBoolType() const { return BoolType; } - inline Type* getStringType() { + inline Type* getStringType() const { return StringType; } - inline Type* getIntType() { + inline Type* getIntType() const { return IntType; } diff --git a/include/bolt/DiagnosticEngine.hpp b/include/bolt/DiagnosticEngine.hpp index 0038315d9..4c4e67834 100644 --- a/include/bolt/DiagnosticEngine.hpp +++ b/include/bolt/DiagnosticEngine.hpp @@ -6,6 +6,8 @@ #include #include "bolt/ByteString.hpp" +#include "bolt/CST.hpp" +#include "bolt/Type.hpp" namespace bolt { @@ -60,6 +62,98 @@ namespace bolt { Magenta, }; + enum StyleFlags : unsigned { + StyleFlags_None = 0, + StyleFlags_Bold = 1 << 0, + StyleFlags_Underline = 1 << 1, + StyleFlags_Italic = 1 << 2, + }; + + class Style { + + unsigned Flags = StyleFlags_None; + + Color FgColor = Color::None; + Color BgColor = Color::None; + + public: + + Color getForegroundColor() const noexcept { + return FgColor; + } + + Color getBackgroundColor() const noexcept { + return BgColor; + } + + void setForegroundColor(Color NewColor) noexcept { + FgColor = NewColor; + } + + void setBackgroundColor(Color NewColor) noexcept { + BgColor = NewColor; + } + + bool hasForegroundColor() const noexcept { + return FgColor != Color::None; + } + + bool hasBackgroundColor() const noexcept { + return BgColor != Color::None; + } + + 
void clearForegroundColor() noexcept { + FgColor = Color::None; + } + + void clearBackgroundColor() noexcept { + BgColor = Color::None; + } + + bool isUnderline() const noexcept { + return Flags & StyleFlags_Underline; + } + + bool isItalic() const noexcept { + return Flags & StyleFlags_Italic; + } + + bool isBold() const noexcept { + return Flags & StyleFlags_Bold; + } + + void setUnderline(bool Enable) noexcept { + if (Enable) { + Flags |= StyleFlags_Underline; + } else { + Flags &= ~StyleFlags_Underline; + } + } + + void setItalic(bool Enable) noexcept { + if (Enable) { + Flags |= StyleFlags_Italic; + } else { + Flags &= ~StyleFlags_Italic; + } + } + + void setBold(bool Enable) noexcept { + if (Enable) { + Flags |= StyleFlags_Bold; + } else { + Flags &= ~StyleFlags_Bold; + } + } + + void reset() noexcept { + FgColor = Color::None; + BgColor = Color::None; + Flags = 0; + } + + }; + /** * Prints any diagnostic message that was added to it to the console. */ @@ -67,8 +161,12 @@ namespace bolt { std::ostream& Out; + Style ActiveStyle; + void setForegroundColor(Color C); void setBackgroundColor(Color C); + void applyStyles(); + void setBold(bool Enable); void setItalic(bool Enable); void setUnderline(bool Enable); @@ -99,6 +197,7 @@ namespace bolt { void writePrefix(const Diagnostic& D); void writeBinding(const ByteString& Name); void writeType(std::size_t I); + void writeType(const Type* Ty, const TypePath& Underline); void writeType(const Type* Ty); void writeLoc(const TextFile& File, const TextLoc& Loc); void writeTypeclassName(const ByteString& Name); diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index d67e50d0b..28763dfc7 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -1,6 +1,7 @@ #pragma once +#include #include #include #include @@ -23,6 +24,7 @@ namespace bolt { ClassNotFound, TupleIndexOutOfRange, InvalidTypeToTypeclass, + FieldNotFound, }; class Diagnostic : std::runtime_error { @@ -88,14 +90,14 @@ 
namespace bolt { class UnificationErrorDiagnostic : public Diagnostic { public: - Type* Left; - Type* Right; + Type* OrigLeft; + Type* OrigRight; TypePath LeftPath; TypePath RightPath; Node* Source; - inline UnificationErrorDiagnostic(Type* Left, Type* Right, TypePath LeftPath, TypePath RightPath, Node* Source): - Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} + inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source): + Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} inline Node* getNode() const override { return Source; @@ -171,4 +173,17 @@ namespace bolt { }; + class FieldNotFoundDiagnostic : public Diagnostic { + public: + + ByteString Name; + Type* Ty; + TypePath Path; + Node* Source; + + inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source): + Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} + + }; + } diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index e82a1b45b..ef20b386d 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -82,6 +82,7 @@ namespace bolt { MatchExpression* parseMatchExpression(); Expression* parseMemberExpression(); + RecordExpression* parseRecordExpression(); Expression* parsePrimitiveExpression(); ConstraintExpression* parseConstraintExpression(); diff --git a/include/bolt/Type.hpp b/include/bolt/Type.hpp index 704449564..cc7f8c257 100644 --- a/include/bolt/Type.hpp +++ b/include/bolt/Type.hpp @@ -1,6 +1,8 @@ #pragma once +#include +#include #include #include #include @@ -28,10 +30,15 @@ namespace bolt { }; enum class TypeIndexKind { + AppOpType, + AppArgType, ArrowParamType, ArrowReturnType, - ConArg, TupleElement, + FieldType, + FieldRestType, + TupleIndexType, + PresentType, End, }; 
@@ -59,22 +66,42 @@ namespace bolt { void advance(const Type* Ty); - static TypeIndex forArrowReturnType() { - return { TypeIndexKind::ArrowReturnType }; + static TypeIndex forFieldType() { + return { TypeIndexKind::FieldType }; + } + + static TypeIndex forFieldRest() { + return { TypeIndexKind::FieldRestType }; } static TypeIndex forArrowParamType(std::size_t I) { return { TypeIndexKind::ArrowParamType, I }; } - static TypeIndex forConArg(std::size_t I) { - return { TypeIndexKind::ConArg, I }; + static TypeIndex forArrowReturnType() { + return { TypeIndexKind::ArrowReturnType }; } static TypeIndex forTupleElement(std::size_t I) { return { TypeIndexKind::TupleElement, I }; } + static TypeIndex forAppOpType() { + return { TypeIndexKind::AppOpType }; + } + + static TypeIndex forAppArgType() { + return { TypeIndexKind::AppArgType }; + } + + static TypeIndex forTupleIndexType() { + return { TypeIndexKind::TupleIndexType }; + } + + static TypeIndex forPresentType() { + return { TypeIndexKind::PresentType }; + } + }; class TypeIterator { @@ -116,9 +143,14 @@ namespace bolt { enum class TypeKind : unsigned char { Var, Con, + App, Arrow, Tuple, TupleIndex, + Field, + Nil, + Absent, + Present, }; class Type { @@ -146,8 +178,18 @@ namespace bolt { return Out; } + /** + * Rewrites the entire substructure of a type to another one. + * + * \param Recursive If true, a successfully rewritten local type will be + * rewritten again until it encounters some terminals. 
+ */ + Type* rewrite(std::function Fn, bool Recursive = false); + Type* substitute(const TVSub& Sub); + Type* solve(); + TypeIterator begin(); TypeIterator end(); @@ -176,11 +218,10 @@ namespace bolt { public: const size_t Id; - std::vector Args; ByteString DisplayName; - inline TCon(const size_t Id, std::vector Args, ByteString DisplayName): - Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {} + inline TCon(const size_t Id, ByteString DisplayName): + Type(TypeKind::Con), Id(Id), DisplayName(DisplayName) {} static bool classof(const Type* Ty) { return Ty->getKind() == TypeKind::Con; @@ -188,12 +229,30 @@ namespace bolt { }; + class TApp : public Type { + public: + + Type* Op; + Type* Arg; + + inline TApp(Type* Op, Type* Arg): + Type(TypeKind::App), Op(Op), Arg(Arg) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::App; + } + + }; + enum class VarKind { Rigid, Unification, }; class TVar : public Type { + + Type* Parent = this; + public: const size_t Id; @@ -208,6 +267,10 @@ namespace bolt { return VK; } + Type* find(); + + void set(Type* Ty); + static bool classof(const Type* Ty) { return Ty->getKind() == TypeKind::Var; } @@ -272,6 +335,215 @@ namespace bolt { }; + class TNil : public Type { + public: + + inline TNil(): + Type(TypeKind::Nil) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Nil; + } + + }; + + class TField : public Type { + public: + + ByteString Name; + Type* Ty; + Type* RestTy; + + inline TField( + ByteString Name, + Type* Ty, + Type* RestTy + ): Type(TypeKind::Field), + Name(Name), + Ty(Ty), + RestTy(RestTy) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Field; + } + + }; + + class TAbsent : public Type { + public: + + inline TAbsent(): + Type(TypeKind::Absent) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Absent; + } + + }; + + class TPresent : public Type { + public: + + Type* Ty; + + inline 
TPresent(Type* Ty): + Type(TypeKind::Present), Ty(Ty) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Present; + } + + }; + + template + class TypeVisitorBase { + protected: + + template + using C = std::conditional::type; + + virtual void enterType(C* Ty) {} + virtual void exitType(C* Ty) {} + + virtual void visitVarType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitAppType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitPresentType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitConType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitArrowType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitTupleType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitTupleIndexType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitAbsentType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitFieldType(C* Ty) { + visitEachChild(Ty); + } + + virtual void visitNilType(C* Ty) { + visitEachChild(Ty); + } + + public: + + void visitEachChild(C* Ty) { + switch (Ty->getKind()) { + case TypeKind::Var: + case TypeKind::Absent: + case TypeKind::Nil: + case TypeKind::Con: + break; + case TypeKind::Arrow: + { + auto Arrow = static_cast*>(Ty); + for (auto I = 0; I < Arrow->ParamTypes.size(); ++I) { + visit(Arrow->ParamTypes[I]); + } + visit(Arrow->ReturnType); + break; + } + case TypeKind::Tuple: + { + auto Tuple = static_cast*>(Ty); + for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) { + visit(Tuple->ElementTypes[I]); + } + break; + } + case TypeKind::App: + { + auto App = static_cast*>(Ty); + visit(App->Op); + visit(App->Arg); + break; + } + case TypeKind::Field: + { + auto Field = static_cast*>(Ty); + visit(Field->Ty); + visit(Field->RestTy); + break; + } + case TypeKind::Present: + { + auto Present = static_cast*>(Ty); + visit(Present->Ty); + break; + } + case TypeKind::TupleIndex: + { + auto Index = static_cast*>(Ty); + visit(Index->Ty); + break; + } + } + } + + void visit(C* Ty) { + 
enterType(Ty); + switch (Ty->getKind()) { + case TypeKind::Present: + visitPresentType(static_cast*>(Ty)); + break; + case TypeKind::Absent: + visitAbsentType(static_cast*>(Ty)); + break; + case TypeKind::Nil: + visitNilType(static_cast*>(Ty)); + break; + case TypeKind::Field: + visitFieldType(static_cast*>(Ty)); + break; + case TypeKind::Con: + visitConType(static_cast*>(Ty)); + break; + case TypeKind::Arrow: + visitArrowType(static_cast*>(Ty)); + break; + case TypeKind::Var: + visitVarType(static_cast*>(Ty)); + break; + case TypeKind::Tuple: + visitTupleType(static_cast*>(Ty)); + break; + case TypeKind::App: + visitAppType(static_cast*>(Ty)); + break; + case TypeKind::TupleIndex: + visitTupleIndexType(static_cast*>(Ty)); + break; + } + exitType(Ty); + } + + virtual ~TypeVisitorBase() {} + + }; + + using TypeVisitor = TypeVisitorBase; + using ConstTypeVisitor = TypeVisitorBase; + // template // struct DerefHash { // std::size_t operator()(const T& Value) const noexcept { diff --git a/src/CST.cc b/src/CST.cc index 16612f0d2..54939c6b9 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -417,6 +417,22 @@ namespace bolt { return BlockStart; } + Token* RecordExpressionField::getFirstToken() const { + return Name; + } + + Token* RecordExpressionField::getLastToken() const { + return E->getLastToken(); + } + + Token* RecordExpression::getFirstToken() const { + return LBrace; + } + + Token* RecordExpression::getLastToken() const { + return RBrace; + } + Token* MemberExpression::getFirstToken() const { return E->getFirstToken(); } diff --git a/src/Checker.cc b/src/Checker.cc index 96a0ccffd..4e789bc03 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -3,18 +3,23 @@ // TODO (maybe) make unficiation work like union-find in find() +// TODO remove Args in TCon and just use it as a constant +// TODO make TApp traversable with TupleIndex + // TODO make simplify() rewrite the types in-place such that a reference too (Bool, Int).0 becomes Bool -// TODO Fix TVSub to use TVar.Id instead 
of the pointer address +// TODO Add a check for datatypes that create infinite structures. -// TODO Deferred diagnostics +// TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node #include #include #include +#include #include "llvm/Support/Casting.h" +#include "bolt/Type.hpp" #include "zen/config.hpp" #include "zen/range.hpp" @@ -58,11 +63,38 @@ namespace bolt { } } + Type* Checker::simplifyType(Type* Ty) { + + return Ty->rewrite([&](auto Ty) { + + if (Ty->getKind() == TypeKind::Var) { + Ty = static_cast(Ty)->find(); + } + + if (Ty->getKind() == TypeKind::TupleIndex) { + auto Index = static_cast(Ty); + auto MaybeTuple = simplifyType(Index->Ty); + if (MaybeTuple->getKind() == TypeKind::Tuple) { + auto Tuple = static_cast(MaybeTuple); + if (Index->I >= Tuple->ElementTypes.size()) { + DE.add(Tuple, Index->I); + } else { + Ty = simplifyType(Tuple->ElementTypes[Index->I]); + } + } + } + + return Ty; + + }, /*Recursive=*/true); + + } + Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): Config(Config), DE(DE) { - BoolType = new TCon(NextConTypeId++, {}, "Bool"); - IntType = new TCon(NextConTypeId++, {}, "Int"); - StringType = new TCon(NextConTypeId++, {}, "String"); + BoolType = createConType("Bool"); + IntType = createConType("Int"); + StringType = createConType("String"); } Scheme* Checker::lookup(ByteString Name) { @@ -233,6 +265,92 @@ namespace bolt { // These declarations will be handled separately in check() break; + case NodeKind::VariantDeclaration: + { + auto Decl = static_cast(X); + + auto& ParentCtx = getContext(); + auto Ctx = createInferContext(); + Contexts.push_back(Ctx); + + std::vector Vars; + for (auto TE: Decl->TVs) { + auto TV = createRigidVar(TE->Name->getCanonicalText()); + Ctx->TVs->emplace(TV); + Vars.push_back(TV); + } + + Type* Ty = createConType(Decl->Name->getCanonicalText()); + + // Must be added early so we can create recursive types + 
ParentCtx.Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty)); + + for (auto Member: Decl->Members) { + switch (Member->getKind()) { + case NodeKind::TupleVariantDeclarationMember: + { + auto TupleMember = static_cast(Member); + auto RetTy = Ty; + for (auto Var: Vars) { + RetTy = new TApp(RetTy, Var); + } + std::vector ParamTypes; + for (auto Element: TupleMember->Elements) { + ParamTypes.push_back(inferTypeExpression(Element)); + } + ParentCtx.Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Ctx->TVs, Ctx->Constraints, new TArrow(ParamTypes, RetTy))); + break; + } + case NodeKind::RecordVariantDeclarationMember: + { + // TODO + break; + } + default: + ZEN_UNREACHABLE + } + } + + Contexts.pop_back(); + + break; + } + + case NodeKind::RecordDeclaration: + { + auto Decl = static_cast(X); + + auto& ParentCtx = getContext(); + auto Ctx = createInferContext(); + Contexts.push_back(Ctx); + std::vector Vars; + for (auto TE: Decl->Vars) { + auto TV = createRigidVar(TE->Name->getCanonicalText()); + Ctx->TVs->emplace(TV); + Vars.push_back(TV); + } + + auto Name = Decl->Name->getCanonicalText(); + auto Ty = createConType(Name); + + // Must be added early so we can create recursive types + ParentCtx.Env.emplace(Name, new Forall(Ty)); + + // Corresponds to the logic of one branch of a VariantDeclarationMember + Type* FieldsTy = new TNil(); + for (auto Field: Decl->Fields) { + FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy); + } + Type* RetTy = Ty; + for (auto TV: Vars) { + RetTy = new TApp(RetTy, TV); + } + Contexts.pop_back(); + addBinding(Name, new Forall(Ctx->TVs, Ctx->Constraints, new TArrow({ FieldsTy }, RetTy))); + + break; + } + default: ZEN_UNREACHABLE @@ -313,7 +431,7 @@ namespace bolt { // e.g. Bool, which causes the type assert to also collapse to e.g. // Bool -> Bool -> Bool. 
for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { - addConstraint(new CEqual(Param, TE->getType())); + addConstraint(new CEqual(Param, TE->getType(), TE)); } } @@ -338,12 +456,14 @@ namespace bolt { } } + Type* BindTy; if (HasContext) { Contexts.pop_back(); - inferBindings(Let->Pattern, Ty, Let->Ctx->Constraints, Let->Ctx->TVs); + BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs); } else { - inferBindings(Let->Pattern, Ty); + BindTy = inferPattern(Let->Pattern); } + addConstraint(new CEqual(BindTy, Ty, Let)); } @@ -364,9 +484,7 @@ namespace bolt { for (auto Param: Decl->Params) { // TODO incorporate Param->TypeAssert or make it a kind of pattern - TVar* TV = createTypeVar(); - inferBindings(Param->Pattern, TV); - ParamTypes.push_back(TV); + ParamTypes.push_back(inferPattern(Param->Pattern)); } if (Decl->Body) { @@ -438,6 +556,11 @@ namespace bolt { break; } + case NodeKind::VariantDeclaration: + case NodeKind::RecordDeclaration: + // Nothing to do for a type-level declaration + break; + case NodeKind::IfStatement: { auto IfStmt = static_cast(N); @@ -482,6 +605,10 @@ namespace bolt { } + TCon* Checker::createConType(ByteString Name) { + return new TCon(NextConTypeId++, Name); + } + TVarRigid* Checker::createRigidVar(ByteString Name) { auto TV = new TVarRigid(NextTypeVarId++, Name); Contexts.back()->TVs->emplace(TV); @@ -533,7 +660,7 @@ namespace bolt { // been solved, with some unification variables being erased. To make // sure we instantiate unification variables that are still in use // we solve before substituting. 
- return simplify(F->Type)->substitute(Sub); + return simplifyType(F->Type)->substitute(Sub); } } @@ -568,15 +695,28 @@ namespace bolt { case NodeKind::ReferenceTypeExpression: { auto RefTE = static_cast(N); - auto Ty = lookupMono(RefTE->Name->getCanonicalText()); - if (Ty == nullptr) { + auto Scm = lookup(RefTE->Name->getCanonicalText()); + Type* Ty; + if (Scm == nullptr) { DE.add(RefTE->Name->getCanonicalText(), RefTE->Name); Ty = createTypeVar(); + } else { + Ty = instantiate(Scm, RefTE); } N->setType(Ty); return Ty; } + case NodeKind::AppTypeExpression: + { + auto AppTE = static_cast(N); + Type* Ty = inferTypeExpression(AppTE->Op); + for (auto Arg: AppTE->Args) { + Ty = new TApp(Ty, inferTypeExpression(Arg)); + } + return Ty; + } + case NodeKind::VarTypeExpression: { auto VarTE = static_cast(N); @@ -588,8 +728,9 @@ namespace bolt { Ty = createRigidVar(VarTE->Name->getCanonicalText()); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); } + ZEN_ASSERT(Ty->getKind() == TypeKind::Var); N->setType(Ty); - return Ty; + return static_cast(Ty); } case NodeKind::TupleTypeExpression: @@ -642,6 +783,19 @@ namespace bolt { } } + Type* sortRow(Type* Ty) { + std::map Fields; + while (Ty->getKind() == TypeKind::Field) { + auto Field = static_cast(Ty); + Fields.emplace(Field->Name, Field); + Ty = Field->RestTy; + } + for (auto [Name, Field]: Fields) { + Ty = new TField(Name, Field->Ty, Ty); + } + return Ty; + } + Type* Checker::inferExpression(Expression* X) { Type* Ty; @@ -661,9 +815,10 @@ namespace bolt { for (auto Case: Match->Cases) { auto NewCtx = createInferContext(); Contexts.push_back(NewCtx); - inferBindings(Case->Pattern, ValTy); - auto ResTy = inferExpression(Case->Expression); - addConstraint(new CEqual(ResTy, Ty, Case->Expression)); + auto PattTy = inferPattern(Case->Pattern); + addConstraint(new CEqual(PattTy, ValTy, X)); + auto ExprTy = inferExpression(Case->Expression); + addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); Contexts.pop_back(); } 
if (!Match->Value) { @@ -672,6 +827,17 @@ namespace bolt { break; } + case NodeKind::RecordExpression: + { + auto Record = static_cast(X); + Ty = new TNil(); + for (auto [Field, Comma]: Record->Fields) { + Ty = new TField(Field->Name->getCanonicalText(), new TPresent(inferExpression(Field->getExpression())), Ty); + } + Ty = sortRow(Ty); + break; + } + case NodeKind::ConstantExpression: { auto Const = static_cast(X); @@ -743,16 +909,20 @@ namespace bolt { case NodeKind::MemberExpression: { auto Member = static_cast(X); + auto ExprTy = inferExpression(Member->E); switch (Member->Name->getKind()) { case NodeKind::IntegerLiteral: { auto I = static_cast(Member->Name); - Ty = new TTupleIndex(inferExpression(Member->E), I->getInteger()); + Ty = new TTupleIndex(ExprTy, I->getInteger()); break; } case NodeKind::Identifier: { - // TODO + auto K = static_cast(Member->Name); + Ty = createTypeVar(); + auto RestTy = createTypeVar(); + addConstraint(new CEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member)); break; } default: @@ -778,9 +948,8 @@ namespace bolt { return Ty; } - void Checker::inferBindings( + Type* Checker::inferPattern( Pattern* Pattern, - Type* Type, ConstraintSet* Constraints, TVSet* TVs ) { @@ -790,15 +959,39 @@ namespace bolt { case NodeKind::BindPattern: { auto P = static_cast(Pattern); - addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Type)); - break; + auto Ty = createTypeVar(); + addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty)); + return Ty; + } + + case NodeKind::NamedPattern: + { + auto P = static_cast(Pattern); + auto Scm = lookup(P->Name->getCanonicalText()); + std::vector ParamTypes; + for (auto P2: P->Patterns) { + ParamTypes.push_back(inferPattern(P2, Constraints, TVs)); + } + if (!Scm) { + DE.add(P->Name->getCanonicalText(), P->Name); + return createTypeVar(); + } + auto Ty = instantiate(Scm, P); + auto RetTy = createTypeVar(); + addConstraint(new CEqual(Ty, new TArrow(ParamTypes, 
RetTy), P)); + return RetTy; + } + + case NodeKind::NestedPattern: + { + auto P = static_cast(Pattern); + return inferPattern(P->P, Constraints, TVs); } case NodeKind::LiteralPattern: { auto P = static_cast(Pattern); - addConstraint(new CEqual(inferLiteral(P->Literal), Type, P)); - break; + return inferLiteral(P->Literal); } default: @@ -808,10 +1001,6 @@ namespace bolt { } - void Checker::inferBindings(Pattern* Pattern, Type* Type) { - inferBindings(Pattern, Type, new ConstraintSet, new TVSet); - } - Type* Checker::inferLiteral(Literal* L) { Type* Ty; switch (L->getKind()) { @@ -927,7 +1116,7 @@ namespace bolt { // This is ugly but it works. Scan all type variables local to this // declaration and add the classes that they require to Actual. for (auto Ty: *Decl->Ctx->TVs) { - auto S = Ty->substitute(C.Solution); + auto S = Ty->solve(); if (llvm::isa(S)) { auto TV = static_cast(S); for (auto Class: TV->Contexts) { @@ -995,6 +1184,10 @@ namespace bolt { } + Type* Checker::getType(TypedNode *Node) { + return Node->getType()->solve(); + } + void Checker::check(SourceFile *SF) { auto RootContext = createInferContext(); Contexts.push_back(RootContext); @@ -1042,11 +1235,11 @@ namespace bolt { } infer(SF); Contexts.pop_back(); - solve(new CMany(*RootContext->Constraints), Solution); + solve(new CMany(*RootContext->Constraints)); checkTypeclassSigs(SF); } - void Checker::solve(Constraint* Constraint, TVSub& Solution) { + void Checker::solve(Constraint* Constraint) { Queue.push_back(Constraint); @@ -1094,12 +1287,13 @@ namespace bolt { if (Con1->Id != Con2-> Id) { return false; } - ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); - for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) { - if (!assignableTo(T1, T2)) { - return false; - } - } + // TODO must handle a TApp + // ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); + // for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) { + // if (!assignableTo(T1, T2)) { + // return false; + // } + // } return true; } 
ZEN_UNREACHABLE @@ -1112,19 +1306,21 @@ namespace bolt { for (auto Instance: Match->second) { if (assignableTo(Ty, Instance->TypeExps[0]->getType())) { std::vector S; - for (auto Arg: Ty->Args) { - TypeclassContext Classes; - // TODO - S.push_back(Classes); - } + // TODO handle TApp + // for (auto Arg: Ty->Args) { + // TypeclassContext Classes; + // // TODO + // S.push_back(Classes); + // } return S; } } } DE.add(Class, Ty, Source); - for (auto Arg: Ty->Args) { - S.push_back({}); - } + // TODO handle TApp + // for (auto Arg: Ty->Args) { + // S.push_back({}); + // } return S; } @@ -1145,114 +1341,15 @@ namespace bolt { void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty) { auto S = findInstanceContext(Ty, Class); - for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { - propagateClasses(Classes, Arg); - } + // TODO handle TApp + // for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { + // propagateClasses(Classes, Arg); + // } }; - void Checker::solveCEqual(CEqual* C) { - // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; - OrigLeft = C->Left; - OrigRight = C->Right; - Source = C->Source; - unify(C->Left, C->Right); - LeftPath = {}; - RightPath = {}; - } - - Type* Checker::find(Type* Ty) { - while (Ty->getKind() == TypeKind::Var) { - auto Match = Solution.find(static_cast(Ty)); - if (Match == Solution.end()) { - break; - } - Ty = Match->second; - } - return Ty; - } - - Type* Checker::simplify(Type* Ty) { - - Ty = find(Ty); - - switch (Ty->getKind()) { - - case TypeKind::Var: - break; - - case TypeKind::Tuple: - { - auto Tuple = static_cast(Ty); - bool Changed = false; - std::vector NewElementTypes; - for (auto Ty: Tuple->ElementTypes) { - auto NewElementType = simplify(Ty); - if (NewElementType != Ty) { - Changed = true; - } - NewElementTypes.push_back(NewElementType); - } - return Changed ? 
new TTuple(NewElementTypes) : Ty; - } - - case TypeKind::Arrow: - { - auto Arrow = static_cast(Ty); - bool Changed = false; - std::vector NewParamTys; - for (auto ParamTy: Arrow->ParamTypes) { - auto NewParamTy = simplify(ParamTy); - if (NewParamTy != ParamTy) { - Changed = true; - } - NewParamTys.push_back(NewParamTy); - } - auto NewRetTy = simplify(Arrow->ReturnType); - if (NewRetTy != Arrow->ReturnType) { - Changed = true; - } - Ty = Changed ? new TArrow(NewParamTys, NewRetTy) : Arrow; - break; - } - - case TypeKind::Con: - { - auto Con = static_cast(Ty); - bool Changed = false; - std::vector NewArgs; - for (auto Arg: Con->Args) { - auto NewArg = simplify(Arg); - if (NewArg != Arg) { - Changed = true; - } - NewArgs.push_back(NewArg); - } - return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : Ty; - } - - case TypeKind::TupleIndex: - { - auto Index = static_cast(Ty); - auto MaybeTuple = simplify(Index->Ty); - if (llvm::isa(MaybeTuple)) { - auto Tuple = static_cast(MaybeTuple); - if (Index->I >= Tuple->ElementTypes.size()) { - DE.add(Tuple, Index->I); - } else { - Ty = simplify(Tuple->ElementTypes[Index->I]); - } - } - break; - } - - } - - return Ty; - } - void Checker::join(TVar* TV, Type* Ty) { - Solution[TV] = Ty; + TV->set(Ty); propagateClasses(TV->Contexts, Ty); @@ -1275,16 +1372,6 @@ namespace bolt { } - void Checker::unifyError() { - DE.add( - simplify(OrigLeft), - simplify(OrigRight), - LeftPath, - RightPath, - Source - ); - } - class ArrowCursor { std::stack> Stack; @@ -1328,180 +1415,356 @@ namespace bolt { }; - bool Checker::unify(Type* A, Type* B) { + struct Unifier { - A = simplify(A); - B = simplify(B); + Checker& C; + CEqual* Constraint; - if (llvm::isa(A) && llvm::isa(B)) { - auto Var1 = static_cast(A); - auto Var2 = static_cast(B); - if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { - if (Var1->Id != Var2->Id) { + // Internal state used by the unifier + ByteString CurrentFieldName; + TypePath LeftPath; + 
TypePath RightPath; + + Type* getLeft() const { + return Constraint->Left; + } + + Type* getRight() const { + return Constraint->Right; + } + + Node* getSource() const { + return Constraint->Source; + } + + bool unify(Type* A, Type* B); + + bool unifyField(Type* A, Type* B); + + bool unify() { + return unify(Constraint->Left, Constraint->Right); + } + + }; + + class UnificationFrame { + + Unifier& U; + Type* A; + Type* B; + bool DidSwap = false; + + public: + + UnificationFrame(Unifier& U, Type* A, Type* B): + U(U), A(U.C.simplifyType(A)), B(U.C.simplifyType(B)) {} + + void unifyError() { + U.C.DE.add( + U.C.simplifyType(U.Constraint->Left), + U.C.simplifyType(U.Constraint->Right), + U.LeftPath, + U.RightPath, + U.Constraint->Source + ); + } + + void pushLeft(TypeIndex I) { + if (DidSwap) { + U.RightPath.push_back(I); + } else { + U.LeftPath.push_back(I); + } + } + + void popLeft() { + if (DidSwap) { + U.RightPath.pop_back(); + } else { + U.LeftPath.pop_back(); + } + } + + void pushRight(TypeIndex I) { + if (DidSwap) { + U.LeftPath.push_back(I); + } else { + U.RightPath.push_back(I); + } + } + + void popRight() { + if (DidSwap) { + U.LeftPath.pop_back(); + } else { + U.RightPath.pop_back(); + } + } + + void swap() { + std::swap(A, B); + DidSwap = !DidSwap; + } + + bool unifyField() { + if (llvm::isa(A) && llvm::isa(B)) { + return true; + } + if (llvm::isa(B)) { + swap(); + } + if (llvm::isa(A)) { + auto Present = static_cast(B); + U.C.DE.add(U.CurrentFieldName, U.C.simplifyType(U.getLeft()), U.LeftPath, U.getSource()); + return false; + } + auto Present1 = static_cast(A); + auto Present2 = static_cast(B); + return U.unify(Present1->Ty, Present2->Ty); + } + + bool unify() { + + if (llvm::isa(A) && llvm::isa(B)) { + auto Var1 = static_cast(A); + auto Var2 = static_cast(B); + if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { + if (Var1->Id != Var2->Id) { + unifyError(); + return false; + } + return true; + } + TVar* To; + TVar* From; 
+ if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { + To = Var1; + From = Var2; + } else { + // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification + // Either way, Var1, being Unification, is a good candidate for being unified away + To = Var2; + From = Var1; + } + if (From->Id != To->Id) { + U.C.join(From, To); + } + return true; + } + + if (llvm::isa(B)) { + swap(); + } + + if (llvm::isa(A)) { + + auto TV = static_cast(A); + + // Rigid type variables can never unify with anything else than what we + // have already handled in the previous if-statement, so issue an error. + if (TV->getVarKind() == VarKind::Rigid) { + unifyError(); + return false; + } + + // Occurs check + if (B->hasTypeVar(TV)) { + // NOTE Just like GHC, we just display an error message indicating that + // A cannot match B, e.g. a cannot match [a]. It looks much better + // than obscure references to an occurs check + unifyError(); + return false; + } + + U.C.join(TV, B); + + return true; + } + + if (llvm::isa(A) && llvm::isa(B)) { + auto C1 = ArrowCursor(static_cast(A), DidSwap ? U.RightPath : U.LeftPath); + auto C2 = ArrowCursor(static_cast(B), DidSwap ? 
U.LeftPath : U.RightPath); + bool Success = true; + for (;;) { + auto T1 = C1.next(); + auto T2 = C2.next(); + if (T1 == nullptr && T2 == nullptr) { + break; + } + if (T1 == nullptr || T2 == nullptr) { + unifyError(); + Success = false; + break; + } + if (!U.unify(T1, T2)) { + Success = false; + } + } + return Success; + /* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */ + /* return false; */ + /* } */ + /* auto Count = Arr1->ParamTypes.size(); */ + /* for (std::size_t I = 0; I < Count; I++) { */ + /* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */ + /* return false; */ + /* } */ + /* } */ + /* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */ + } + + if (llvm::isa(A) && llvm::isa(B)) { + auto App1 = static_cast(A); + auto App2 = static_cast(B); + bool Success = true; + if (!U.unify(App1->Op, App2->Op)) { + Success = false; + } + if (!U.unify(App1->Arg, App2->Arg)) { + Success = false; + } + return Success; + } + + if (llvm::isa(B)) { + swap(); + } + + if (llvm::isa(A)) { + auto Arr = static_cast(A); + if (Arr->ParamTypes.empty()) { + auto Success = U.unify(Arr->ReturnType, B); + return Success; + } + } + + if (llvm::isa(A) && llvm::isa(B)) { + auto Tuple1 = static_cast(A); + auto Tuple2 = static_cast(B); + if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { + unifyError(); + return false; + } + auto Count = Tuple1->ElementTypes.size(); + bool Success = true; + for (size_t I = 0; I < Count; I++) { + U.LeftPath.push_back(TypeIndex::forTupleElement(I)); + U.RightPath.push_back(TypeIndex::forTupleElement(I)); + if (!U.unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) { + Success = false; + } + U.LeftPath.pop_back(); + U.RightPath.pop_back(); + } + return Success; + } + + if (llvm::isa(A) || llvm::isa(B)) { + // Type(s) could not be simplified at the beginning of this function, + // so we have to re-visit the constraint when there is more information. 
+ U.C.Queue.push_back(U.Constraint); + return true; + } + + // if (llvm::isa(A) && llvm::isa(B)) { + // auto Index1 = static_cast(A); + // auto Index2 = static_cast(B); + // return unify(Index1->Ty, Index2->Ty, Source); + // } + + if (llvm::isa(A) && llvm::isa(B)) { + auto Con1 = static_cast(A); + auto Con2 = static_cast(B); + if (Con1->Id != Con2->Id) { unifyError(); return false; } return true; } - TVar* To; - TVar* From; - if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { - To = Var1; - From = Var2; - } else { - // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification - // Either way, Var1, being Unification, is a good candidate for being unified away - To = Var2; - From = Var1; - } - if (From->Id != To->Id) { - join(From, To); - } - return true; - } - if (llvm::isa(A)) { - - auto TV = static_cast(A); - - // Rigid type variables can never unify with antything else than what we - // have already handled in the previous if-statement, so issue an error. - if (TV->getVarKind() == VarKind::Rigid) { - unifyError(); - return false; + if (llvm::isa(A) && llvm::isa(B)) { + return true; } - // Occurs check - if (B->hasTypeVar(TV)) { - // NOTE Just like GHC, we just display an error message indicating that - // A cannot match B, e.g. a cannot match [a]. 
It looks much better - // than obsure references to an occurs check - unifyError(); - return false; - } - - join(TV, B); - - return true; - } - - if (llvm::isa(B)) { - return unify(B, A); - } - - if (llvm::isa(A) && llvm::isa(B)) { - auto C1 = ArrowCursor(static_cast(A), LeftPath); - auto C2 = ArrowCursor(static_cast(B), RightPath); - bool Success = true; - for (;;) { - auto T1 = C1.next(); - auto T2 = C2.next(); - if (T1 == nullptr && T2 == nullptr) { - break; + if (llvm::isa(A) && llvm::isa(B)) { + auto Field1 = static_cast(A); + auto Field2 = static_cast(B); + bool Success = true; + if (Field1->Name == Field2->Name) { + U.LeftPath.push_back(TypeIndex::forFieldType()); + U.RightPath.push_back(TypeIndex::forFieldType()); + U.CurrentFieldName = Field1->Name; + if (!U.unifyField(Field1->Ty, Field2->Ty)) { + Success = false; + } + U.LeftPath.pop_back(); + U.RightPath.pop_back(); + U.LeftPath.push_back(TypeIndex::forFieldRest()); + U.RightPath.push_back(TypeIndex::forFieldRest()); + if (!U.unify(Field1->RestTy, Field2->RestTy)) { + Success = false; + } + U.LeftPath.pop_back(); + U.RightPath.pop_back(); + return Success; } - if (T1 == nullptr || T2 == nullptr) { - unifyError(); - Success = false; - break; - } - if (!unify(T1, T2)) { + auto NewRestTy = new TVar(U.C.NextTypeVarId++, VarKind::Unification); + pushLeft(TypeIndex::forFieldRest()); + if (!U.unify(Field1->RestTy, new TField(Field2->Name, Field2->Ty, NewRestTy))) { Success = false; } - } - return Success; - /* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */ - /* return false; */ - /* } */ - /* auto Count = Arr1->ParamTypes.size(); */ - /* for (std::size_t I = 0; I < Count; I++) { */ - /* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */ - /* return false; */ - /* } */ - /* } */ - /* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */ - } - - if (llvm::isa(A)) { - auto Arr = static_cast(A); - if (Arr->ParamTypes.empty()) { - return unify(Arr->ReturnType, B); - } - } - - 
if (llvm::isa(B)) { - return unify(B, A); - } - - if (llvm::isa(A) && llvm::isa(B)) { - auto Tuple1 = static_cast(A); - auto Tuple2 = static_cast(B); - if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { - unifyError(); - return false; - } - auto Count = Tuple1->ElementTypes.size(); - bool Success = true; - for (size_t I = 0; I < Count; I++) { - LeftPath.push_back(TypeIndex::forTupleElement(I)); - RightPath.push_back(TypeIndex::forTupleElement(I)); - if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) { + popLeft(); + pushRight(TypeIndex::forFieldRest()); + if (!U.unify(new TField(Field1->Name, Field1->Ty, NewRestTy), Field2->RestTy)) { Success = false; } - LeftPath.pop_back(); - RightPath.pop_back(); + popRight(); + return Success; } - return Success; - } - if (llvm::isa(A) || llvm::isa(B)) { - Queue.push_back(C); - return true; - } - - // if (llvm::isa(A) && llvm::isa(B)) { - // auto Index1 = static_cast(A); - // auto Index2 = static_cast(B); - // return unify(Index1->Ty, Index2->Ty, Source); - // } - - if (llvm::isa(A) && llvm::isa(B)) { - auto Con1 = static_cast(A); - auto Con2 = static_cast(B); - if (Con1->Id != Con2->Id) { - unifyError(); - return false; + if (llvm::isa(A) && llvm::isa(B)) { + swap(); } - ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); - auto Count = Con1->Args.size(); - bool Success = true; - for (std::size_t I = 0; I < Count; I++) { - LeftPath.push_back(TypeIndex::forConArg(I)); - RightPath.push_back(TypeIndex::forConArg(I)); - if (!unify(Con1->Args[I], Con2->Args[I])) { + + if (llvm::isa(A) && llvm::isa(B)) { + auto Field = static_cast(A); + bool Success = true; + pushLeft(TypeIndex::forFieldType()); + U.CurrentFieldName = Field->Name; + if (!U.unifyField(Field->Ty, new TAbsent)) { Success = false; } - LeftPath.pop_back(); - RightPath.pop_back(); + popLeft(); + pushLeft(TypeIndex::forFieldRest()); + if (!U.unify(Field->RestTy, B)) { + Success = false; + } + popLeft(); + return Success; } - return Success; + + 
unifyError(); + return false; } - unifyError(); - return false; + }; + + bool Unifier::unify(Type* A, Type* B) { + UnificationFrame Frame { *this, A, B }; + return Frame.unify(); } - InferContext* Checker::lookupCall(Node* Source, SymbolPath Path) { - auto Def = Source->getScope()->lookup(Path); - auto Match = CallGraph.find(Def); - if (Match == CallGraph.end()) { - return nullptr; - } - return Match->second; + bool Unifier::unifyField(Type* A, Type* B) { + UnificationFrame Frame { *this, A, B }; + return Frame.unifyField(); } - Type* Checker::getType(TypedNode *Node) { - return Node->getType()->substitute(Solution); + void Checker::solveCEqual(CEqual* C) { + // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; + Unifier A { *this, C }; + A.unify(); } + } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 14225edfe..e362b7a87 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -13,6 +13,7 @@ #define ANSI_RESET "\u001b[0m" #define ANSI_BOLD "\u001b[1m" +#define ANSI_ITALIC "\u001b[3m" #define ANSI_UNDERLINE "\u001b[4m" #define ANSI_REVERSED "\u001b[7m" @@ -107,6 +108,16 @@ namespace bolt { return "'return'"; case NodeKind::TypeKeyword: return "'type'"; + case NodeKind::LetDeclaration: + return "a let-declaration"; + case NodeKind::CallExpression: + return "a call-expression"; + case NodeKind::InfixExpression: + return "an infix-expression"; + case NodeKind::ReferenceExpression: + return "a function or variable reference"; + case NodeKind::MatchExpression: + return "a match-expression"; default: ZEN_UNREACHABLE } @@ -151,16 +162,12 @@ namespace bolt { case TypeKind::Con: { auto Y = static_cast(Ty); - std::ostringstream Out; - if (!Y->DisplayName.empty()) { - Out << Y->DisplayName; - } else { - Out << "C" << Y->Id; - } - for (auto Arg: Y->Args) { - Out << " " << describe(Arg); - } - return Out.str(); + return Y->DisplayName; + } + case TypeKind::App: + { + auto Y = static_cast(Ty); + return describe(Y->Op) + " " + 
describe(Y->Arg); } case TypeKind::Tuple: { @@ -182,6 +189,94 @@ namespace bolt { auto Y = static_cast(Ty); return describe(Y->Ty) + "." + std::to_string(Y->I); } + case TypeKind::Nil: + return "{}"; + case TypeKind::Absent: + return "Abs"; + case TypeKind::Present: + { + auto Y = static_cast(Ty); + return describe(Y->Ty); + } + case TypeKind::Field: + { + auto Y = static_cast(Ty); + std::ostringstream out; + out << "{ " << Y->Name << ": " << describe(Y->Ty); + Ty = Y->RestTy; + while (Ty->getKind() == TypeKind::Field) { + auto Y = static_cast(Ty); + out << "; " + Y->Name + ": " + describe(Y->Ty); + Ty = Y->RestTy; + } + if (Ty->getKind() != TypeKind::Nil) { + out << "; " + describe(Ty); + } + out << " }"; + return out.str(); + } + } + } + + void writeForegroundANSI(Color C, std::ostream& Out) { + switch (C) { + case Color::None: + break; + case Color::Black: + Out << ANSI_FG_BLACK; + break; + case Color::White: + Out << ANSI_FG_WHITE; + break; + case Color::Red: + Out << ANSI_FG_RED; + break; + case Color::Yellow: + Out << ANSI_FG_YELLOW; + break; + case Color::Green: + Out << ANSI_FG_GREEN; + break; + case Color::Blue: + Out << ANSI_FG_BLUE; + break; + case Color::Cyan: + Out << ANSI_FG_CYAN; + break; + case Color::Magenta: + Out << ANSI_FG_MAGENTA; + break; + } + } + + void writeBackgroundANSI(Color C, std::ostream& Out) { + switch (C) { + case Color::None: + break; + case Color::Black: + Out << ANSI_BG_BLACK; + break; + case Color::White: + Out << ANSI_BG_WHITE; + break; + case Color::Red: + Out << ANSI_BG_RED; + break; + case Color::Yellow: + Out << ANSI_BG_YELLOW; + break; + case Color::Green: + Out << ANSI_BG_GREEN; + break; + case Color::Blue: + Out << ANSI_BG_BLUE; + break; + case Color::Cyan: + Out << ANSI_BG_CYAN; + break; + case Color::Magenta: + Out << ANSI_BG_MAGENTA; + break; } } @@ -195,91 +290,84 @@ namespace bolt { Out(Out) {} void ConsoleDiagnostics::setForegroundColor(Color C) { - if (EnableColors) { - switch (C) { - case Color::None: - break; - 
case Color::Black: - Out << ANSI_FG_BLACK; - break; - case Color::White: - Out << ANSI_FG_WHITE; - break; - case Color::Red: - Out << ANSI_FG_RED; - break; - case Color::Yellow: - Out << ANSI_FG_YELLOW; - break; - case Color::Green: - Out << ANSI_FG_GREEN; - break; - case Color::Blue: - Out << ANSI_FG_BLUE; - break; - case Color::Cyan: - Out << ANSI_FG_CYAN; - break; - case Color::Magenta: - Out << ANSI_FG_MAGENTA; - break; - } + ActiveStyle.setForegroundColor(C); + if (!EnableColors) { + return; } + writeForegroundANSI(C, Out); } - void ConsoleDiagnostics::setBackgroundColor(Color C) { - if (EnableColors) { - switch (C) { - case Color::None: - break; - case Color::Black: - Out << ANSI_BG_BLACK; - break; - case Color::White: - Out << ANSI_BG_WHITE; - break; - case Color::Red: - Out << ANSI_BG_RED; - break; - case Color::Yellow: - Out << ANSI_BG_YELLOW; - break; - case Color::Green: - Out << ANSI_BG_GREEN; - break; - case Color::Blue: - Out << ANSI_BG_BLUE; - break; - case Color::Cyan: - Out << ANSI_BG_CYAN; - break; - case Color::Magenta: - Out << ANSI_BG_MAGENTA; - break; - } + ActiveStyle.setBackgroundColor(C); + if (!EnableColors) { + return; + } + if (C == Color::None) { + Out << ANSI_RESET; + applyStyles(); + } + writeBackgroundANSI(C, Out); + } + + void ConsoleDiagnostics::applyStyles() { + if (ActiveStyle.isBold()) { + Out << ANSI_BOLD; + } + if (ActiveStyle.isUnderline()) { + Out << ANSI_UNDERLINE; + } + if (ActiveStyle.isItalic()) { + Out << ANSI_ITALIC; + } + if (ActiveStyle.hasBackgroundColor()) { + setBackgroundColor(ActiveStyle.getBackgroundColor()); + } + if (ActiveStyle.hasForegroundColor()) { + setForegroundColor(ActiveStyle.getForegroundColor()); } } void ConsoleDiagnostics::setBold(bool Enable) { + ActiveStyle.setBold(Enable); + if (!EnableColors) { + return; + } if (Enable) { Out << ANSI_BOLD; + } else { + Out << ANSI_RESET; + applyStyles(); } } void ConsoleDiagnostics::setItalic(bool Enable) { + ActiveStyle.setItalic(Enable); + if 
(!EnableColors) { + return; + } if (Enable) { - // TODO + Out << ANSI_ITALIC; + } else { + Out << ANSI_RESET; + applyStyles(); } } void ConsoleDiagnostics::setUnderline(bool Enable) { + ActiveStyle.setUnderline(Enable); + if (!EnableColors) { + return; + } if (Enable) { Out << ANSI_UNDERLINE; + } else { + Out << ANSI_RESET; + applyStyles(); } } void ConsoleDiagnostics::resetStyles() { + ActiveStyle.reset(); if (EnableColors) { Out << ANSI_RESET; } @@ -391,8 +479,159 @@ namespace bolt { } void ConsoleDiagnostics::writeType(const Type* Ty) { + TypePath Path; + writeType(Ty, Path); + } + + void ConsoleDiagnostics::writeType(const Type* Ty, const TypePath& Underline) { + setForegroundColor(Color::Green); - write(describe(Ty)); + + class TypePrinter : public ConstTypeVisitor { + + TypePath Path; + ConsoleDiagnostics& W; + const TypePath& Underline; + + public: + + TypePrinter(ConsoleDiagnostics& W, const TypePath& Underline): + W(W), Underline(Underline) {} + + bool shouldUnderline() const { + return !Underline.empty() && Path == Underline; + } + + void enterType(const Type* Ty) override { + if (shouldUnderline()) { + W.setUnderline(true); + } + } + + void exitType(const Type* Ty) override { + if (shouldUnderline()) { + W.setUnderline(false); + } + } + + void visitAppType(const TApp *Ty) override { + auto Y = static_cast(Ty); + Path.push_back(TypeIndex::forAppOpType()); + visit(Y->Op); + Path.pop_back(); + W.write(" "); + Path.push_back(TypeIndex::forAppArgType()); + visit(Y->Arg); + Path.pop_back(); + } + + void visitVarType(const TVar* Ty) override { + if (Ty->getVarKind() == VarKind::Rigid) { + W.write(static_cast(Ty)->Name); + return; + } + W.write("a"); + W.write(Ty->Id); + } + + void visitConType(const TCon *Ty) override { + W.write(Ty->DisplayName); + } + + void visitArrowType(const TArrow* Ty) override { + W.write("("); + bool First = true; + std::size_t I = 0; + for (auto PT: Ty->ParamTypes) { + if (First) First = false; + else W.write(", "); + 
Path.push_back(TypeIndex::forArrowParamType(I++)); + visit(PT); + Path.pop_back(); + } + W.write(") -> "); + Path.push_back(TypeIndex::forArrowReturnType()); + visit(Ty->ReturnType); + Path.pop_back(); + } + + void visitTupleType(const TTuple *Ty) override { + W.write("("); + if (Ty->ElementTypes.size()) { + auto Iter = Ty->ElementTypes.begin(); + Path.push_back(TypeIndex::forTupleElement(0)); + visit(*Iter++); + Path.pop_back(); + std::size_t I = 1; + while (Iter != Ty->ElementTypes.end()) { + W.write(", "); + Path.push_back(TypeIndex::forTupleElement(I++)); + visit(*Iter++); + Path.pop_back(); + } + } + W.write(")"); + } + + void visitTupleIndexType(const TTupleIndex *Ty) override { + Path.push_back(TypeIndex::forTupleIndexType()); + visit(Ty->Ty); + Path.pop_back(); + W.write("."); + W.write(Ty->I); + } + + void visitNilType(const TNil *Ty) override { + W.write("{}"); + } + + void visitAbsentType(const TAbsent *Ty) override { + W.write("Abs"); + } + + void visitPresentType(const TPresent *Ty) override { + Path.push_back(TypeIndex::forPresentType()); + visit(Ty->Ty); + Path.pop_back(); + } + + void visitFieldType(const TField* Ty) override { + W.write("{ "); + W.write(Ty->Name); + W.write(": "); + Path.push_back(TypeIndex::forFieldType()); + visit(Ty->Ty); + Path.pop_back(); + auto Ty2 = Ty->RestTy; + Path.push_back(TypeIndex::forFieldRest()); + std::size_t I = 1; + while (Ty2->getKind() == TypeKind::Field) { + auto Y = static_cast(Ty2); + W.write("; "); + W.write(Y->Name); + W.write(": "); + Path.push_back(TypeIndex::forFieldType()); + visit(Y->Ty); + Path.pop_back(); + Ty2 = Y->RestTy; + Path.push_back(TypeIndex::forFieldRest()); + ++I; + } + if (Ty2->getKind() != TypeKind::Nil) { + W.write("; "); + visit(Ty2); + } + W.write(" }"); + for (std::size_t K = 0; K < I; K++) { + Path.pop_back(); + } + } + + }; + + TypePrinter P { *this, Underline }; + P.visit(Ty); + resetStyles(); } @@ -533,40 +772,51 @@ namespace bolt { case DiagnosticKind::UnificationError: { auto E = 
static_cast(D); + auto Left = E.OrigLeft->resolve(E.LeftPath); + auto Right = E.OrigRight->resolve(E.RightPath); writePrefix(E); - auto Left = E.Left->resolve(E.LeftPath); - auto Right = E.Right->resolve(E.RightPath); write("the types "); writeType(Left); write(" and "); writeType(Right); write(" failed to match\n\n"); - if (E.Source) { - writeNode(E.Source); - Out << "\n"; - } - if (!E.LeftPath.empty()) { - setForegroundColor(Color::Yellow); - setBold(true); - write(" info: "); - resetStyles(); - write("the type "); - writeType(Left); - write(" occurs in the full type "); - writeType(E.Left); - write("\n\n"); - } - if (!E.RightPath.empty()) { - setForegroundColor(Color::Yellow); - setBold(true); - write(" info: "); - resetStyles(); - write("the type "); - writeType(Right); - write(" occurs in the full type "); - writeType(E.Right); - write("\n\n"); - } + setForegroundColor(Color::Yellow); + setBold(true); + write(" info: "); + resetStyles(); + write("due to an equality constraint on "); + write(describe(E.Source->getKind())); + write(":\n\n"); + write(" - left type "); + writeType(E.OrigLeft, E.LeftPath); + write("\n"); + write(" - right type "); + writeType(E.OrigRight, E.RightPath); + write("\n\n"); + writeNode(E.Source); + write("\n"); + // if (E.Left != E.OrigLeft) { + // setForegroundColor(Color::Yellow); + // setBold(true); + // write(" info: "); + // resetStyles(); + // write("the type "); + // writeType(E.Left); + // write(" occurs in the full type "); + // writeType(E.OrigLeft); + // write("\n\n"); + // } + // if (E.Right != E.OrigRight) { + // setForegroundColor(Color::Yellow); + // setBold(true); + // write(" info: "); + // resetStyles(); + // write("the type "); + // writeType(E.Right); + // write(" occurs in the full type "); + // writeType(E.OrigRight); + // write("\n\n"); + // } break; } @@ -634,6 +884,18 @@ namespace bolt { break; } + case DiagnosticKind::FieldNotFound: + { + auto E = static_cast(D); + writePrefix(E); + write("the field '"); + 
write(E.Name); + write("' was required in one type but not found in another\n\n"); + writeNode(E.Source); + write("\n"); + break; + } + } } diff --git a/src/Parser.cc b/src/Parser.cc index 76645f687..0c60f35c0 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -473,6 +473,75 @@ after_tuple_element: return new MatchExpression(static_cast(T0), Value, BlockStart, Cases); } + RecordExpression* Parser::parseRecordExpression() { + auto LBrace = expectToken(); + if (!LBrace) { + return nullptr; + } + RBrace* RBrace; + auto T1 = Tokens.peek(); + std::vector> Fields; + if (T1->getKind() == NodeKind::RBrace) { + Tokens.get(); + RBrace = static_cast(T1); + } else { + for (;;) { + auto Name = expectToken(); + if (!Name) { + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + return nullptr; + } + auto Equals = expectToken(); + if (!Equals) { + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + return nullptr; + } + auto E = parseExpression(); + if (!E) { + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + Equals->unref(); + return nullptr; + } + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::Comma) { + Tokens.get(); + Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, static_cast(T2))); + } else if (T2->getKind() == NodeKind::RBrace) { + Tokens.get(); + RBrace = static_cast(T2); + Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, nullptr)); + break; + } else { + DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RBrace }); + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + Equals->unref(); + E->unref(); + return nullptr; + } + } + } + return new RecordExpression { LBrace, Fields, RBrace }; + } + Expression* Parser::parsePrimitiveExpression() { auto T0 = Tokens.peek(); 
switch (T0->getKind()) { @@ -562,9 +631,11 @@ after_tuple_elements: case NodeKind::StringLiteral: Tokens.get(); return new ConstantExpression(static_cast(T0)); + case NodeKind::LBrace: + return parseRecordExpression(); default: // Tokens.get(); - DE.add(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::IntegerLiteral, NodeKind::StringLiteral }); + DE.add(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::LBrace, NodeKind::IntegerLiteral, NodeKind::StringLiteral }); return nullptr; } } @@ -603,7 +674,12 @@ finish: std::vector Args; for (;;) { auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || T1->getKind() == NodeKind::Comma || ExprOperators.isInfix(T1)) { + if (T1->getKind() == NodeKind::LineFoldEnd + || T1->getKind() == NodeKind::RParen + || T1->getKind() == NodeKind::RBrace + || T1->getKind() == NodeKind::BlockStart + || T1->getKind() == NodeKind::Comma + || ExprOperators.isInfix(T1)) { break; } auto Arg = parsePrimitiveExpression(); diff --git a/src/Types.cc b/src/Types.cc index 10ab6e089..80f85d836 100644 --- a/src/Types.cc +++ b/src/Types.cc @@ -28,7 +28,6 @@ namespace bolt { return false; } switch (Kind) { - case TypeIndexKind::ConArg: case TypeIndexKind::ArrowParamType: case TypeIndexKind::TupleElement: return I == Other.I; @@ -41,6 +40,9 @@ namespace bolt { switch (Kind) { case TypeIndexKind::End: break; + case TypeIndexKind::AppOpType: + Kind = TypeIndexKind::AppArgType; + break; case TypeIndexKind::ArrowParamType: { auto Arrow = llvm::cast(Ty); @@ -51,19 +53,16 @@ namespace bolt { } break; } + case TypeIndexKind::FieldType: + Kind = TypeIndexKind::FieldRestType; + break; + case TypeIndexKind::FieldRestType: + case TypeIndexKind::TupleIndexType: + case TypeIndexKind::PresentType: + case TypeIndexKind::AppArgType: 
case TypeIndexKind::ArrowReturnType: Kind = TypeIndexKind::End; break; - case TypeIndexKind::ConArg: - { - auto Con = llvm::cast(Ty); - if (I+1 < Con->Args.size()) { - ++I; - } else { - Kind = TypeIndexKind::End; - } - break; - } case TypeIndexKind::TupleElement: { auto Tuple = llvm::cast(Ty); @@ -77,6 +76,95 @@ namespace bolt { } } + Type* Type::rewrite(std::function Fn, bool Recursive) { + auto Ty2 = Fn(this); + if (!Recursive && this != Ty2) { + return Ty2; + } + switch (Kind) { + case TypeKind::Var: + return Ty2; + case TypeKind::Arrow: + { + auto Arrow = static_cast(Ty2); + bool Changed = false; + std::vector NewParamTypes; + for (auto Ty: Arrow->ParamTypes) { + auto NewParamType = Ty->rewrite(Fn); + if (NewParamType != Ty) { + Changed = true; + } + NewParamTypes.push_back(NewParamType); + } + auto NewRetTy = Arrow->ReturnType->rewrite(Fn); + if (NewRetTy != Arrow->ReturnType) { + Changed = true; + } + return Changed ? new TArrow(NewParamTypes, NewRetTy) : Ty2; + } + case TypeKind::Con: + return Ty2; + case TypeKind::App: + { + auto App = static_cast(Ty2); + auto NewOp = App->Op->rewrite(Fn); + auto NewArg = App->Arg->rewrite(Fn); + if (NewOp == App->Op && NewArg == App->Arg) { + return App; + } + return new TApp(NewOp, NewArg); + } + case TypeKind::TupleIndex: + { + auto Tuple = static_cast(Ty2); + auto NewTy = Tuple->Ty->rewrite(Fn); + return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple; + } + case TypeKind::Tuple: + { + auto Tuple = static_cast(Ty2); + bool Changed = false; + std::vector NewElementTypes; + for (auto Ty: Tuple->ElementTypes) { + auto NewElementType = Ty->rewrite(Fn); + if (NewElementType != Ty) { + Changed = true; + } + NewElementTypes.push_back(NewElementType); + } + return Changed ? 
new TTuple(NewElementTypes) : Ty2; + } + case TypeKind::Nil: + return Ty2; + case TypeKind::Absent: + return Ty2; + case TypeKind::Field: + { + auto Field = static_cast(Ty2); + bool Changed = false; + auto NewTy = Field->Ty->rewrite(Fn); + if (NewTy != Field->Ty) { + Changed = true; + } + auto NewRestTy = Field->RestTy->rewrite(Fn); + if (NewRestTy != Field->RestTy) { + Changed = true; + } + return Changed ? new TField(Field->Name, NewTy, NewRestTy) : Ty2; + } + case TypeKind::Present: + { + auto Present = static_cast(Ty2); + auto NewTy = Present->Ty->rewrite(Fn); + if (NewTy == Present->Ty) { + return Ty2; + } + return new TPresent(NewTy); + } + } + + } + void Type::addTypeVars(TVSet& TVs) { switch (Kind) { case TypeKind::Var: @@ -92,11 +180,12 @@ namespace bolt { break; } case TypeKind::Con: + break; + case TypeKind::App: { - auto Con = static_cast(this); - for (auto Ty: Con->Args) { - Ty->addTypeVars(TVs); - } + auto App = static_cast(this); + App->Op->addTypeVars(TVs); + App->Arg->addTypeVars(TVs); break; } case TypeKind::TupleIndex: @@ -113,6 +202,23 @@ namespace bolt { } break; } + case TypeKind::Nil: + break; + case TypeKind::Field: + { + auto Field = static_cast(this); + Field->Ty->addTypeVars(TVs); + Field->RestTy->addTypeVars(TVs); + break; + } + case TypeKind::Present: + { + auto Present = static_cast(this); + Present->Ty->addTypeVars(TVs); + break; + } + case TypeKind::Absent: + break; } } @@ -131,14 +237,11 @@ namespace bolt { return Arrow->ReturnType->hasTypeVar(TV); } case TypeKind::Con: - { - auto Con = static_cast(this); - for (auto Ty: Con->Args) { - if (Ty->hasTypeVar(TV)) { - return true; - } - } return false; + case TypeKind::App: + { + auto App = static_cast(this); + return App->Op->hasTypeVar(TV) || App->Arg->hasTypeVar(TV); } case TypeKind::TupleIndex: { @@ -155,173 +258,181 @@ namespace bolt { } return false; } + case TypeKind::Nil: + return false; + case TypeKind::Field: + { + auto Field = static_cast(this); + return Field->Ty->hasTypeVar(TV) 
|| Field->RestTy->hasTypeVar(TV); + } + case TypeKind::Present: + { + auto Present = static_cast(this); + return Present->Ty->hasTypeVar(TV); + } + case TypeKind::Absent: + return false; } } + Type* Type::solve() { + return rewrite([](auto Ty) { + if (Ty->getKind() == TypeKind::Var) { + return static_cast(Ty)->find(); + } + return Ty; + }); + } + Type* Type::substitute(const TVSub &Sub) { - switch (Kind) { - case TypeKind::Var: - { - auto TV = static_cast(this); + return rewrite([&](auto Ty) { + if (llvm::isa(Ty)) { + auto TV = static_cast(Ty); auto Match = Sub.find(TV); - return Match != Sub.end() ? Match->second->substitute(Sub) : this; + return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; } - case TypeKind::Arrow: - { - auto Arrow = static_cast(this); - bool Changed = false; - std::vector NewParamTypes; - for (auto Ty: Arrow->ParamTypes) { - auto NewParamType = Ty->substitute(Sub); - if (NewParamType != Ty) { - Changed = true; - } - NewParamTypes.push_back(NewParamType); - } - auto NewRetTy = Arrow->ReturnType->substitute(Sub) ; - if (NewRetTy != Arrow->ReturnType) { - Changed = true; - } - return Changed ? new TArrow(NewParamTypes, NewRetTy) : this; - } - case TypeKind::Con: - { - auto Con = static_cast(this); - bool Changed = false; - std::vector NewArgs; - for (auto Arg: Con->Args) { - auto NewArg = Arg->substitute(Sub); - if (NewArg != Arg) { - Changed = true; - } - NewArgs.push_back(NewArg); - } - return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this; - } - case TypeKind::TupleIndex: - { - auto Tuple = static_cast(this); - auto NewTy = Tuple->Ty->substitute(Sub); - return NewTy != Tuple->Ty ? 
new TTupleIndex(NewTy, Tuple->I) : Tuple; - } - case TypeKind::Tuple: - { - auto Tuple = static_cast(this); - bool Changed = false; - std::vector NewElementTypes; - for (auto Ty: Tuple->ElementTypes) { - auto NewElementType = Ty->substitute(Sub); - if (NewElementType != Ty) { - Changed = true; - } - NewElementTypes.push_back(NewElementType); - } - return Changed ? new TTuple(NewElementTypes) : this; - } - } + return Ty; + }); } Type* Type::resolve(const TypeIndex& Index) const noexcept { switch (Index.Kind) { - case TypeIndexKind::ConArg: - return llvm::cast(this)->Args[Index.I]; + case TypeIndexKind::PresentType: + return llvm::cast(this)->Ty; + case TypeIndexKind::AppOpType: + return llvm::cast(this)->Op; + case TypeIndexKind::AppArgType: + return llvm::cast(this)->Arg; + case TypeIndexKind::TupleIndexType: + return llvm::cast(this)->Ty; case TypeIndexKind::TupleElement: return llvm::cast(this)->ElementTypes[Index.I]; case TypeIndexKind::ArrowParamType: return llvm::cast(this)->ParamTypes[Index.I]; case TypeIndexKind::ArrowReturnType: return llvm::cast(this)->ReturnType; + case TypeIndexKind::FieldType: + return llvm::cast(this)->Ty; + case TypeIndexKind::FieldRestType: + return llvm::cast(this)->RestTy; case TypeIndexKind::End: ZEN_UNREACHABLE } ZEN_UNREACHABLE } - bool Type::operator==(const Type& Other) const noexcept { - switch (Kind) { - case TypeKind::Var: - if (Other.Kind != TypeKind::Var) { - return false; - } - return static_cast(this)->Id == static_cast(Other).Id; - case TypeKind::Tuple: - { - if (Other.Kind != TypeKind::Tuple) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - if (A.ElementTypes.size() != B.ElementTypes.size()) { - return false; - } - for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) { - if (*T1 != *T2) { - return false; - } - } - return true; - } - case TypeKind::TupleIndex: - { - if (Other.Kind != TypeKind::TupleIndex) { - return false; - } - auto A = static_cast(*this); - auto B = 
static_cast(Other); - return A.I == B.I && *A.Ty == *B.Ty; - } - case TypeKind::Con: - { - if (Other.Kind != TypeKind::Con) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - if (A.Id != B.Id) { - return false; - } - if (A.Args.size() != B.Args.size()) { - return false; - } - for (auto [T1, T2]: zen::zip(A.Args, B.Args)) { - if (*T1 != *T2) { - return false; - } - } - return true; - } - case TypeKind::Arrow: - { - // FIXME Do we really need to 'curry' this type? - if (Other.Kind != TypeKind::Arrow) { - return false; - } - auto A = static_cast(*this); - auto B = static_cast(Other); - /* ArrowCursor C1 { &A }; */ - /* ArrowCursor C2 { &B }; */ - /* for (;;) { */ - /* auto T1 = C1.next(); */ - /* auto T2 = C2.next(); */ - /* if (T1 == nullptr && T2 == nullptr) { */ - /* break; */ - /* } */ - /* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */ - /* return false; */ - /* } */ - /* } */ - if (A.ParamTypes.size() != B.ParamTypes.size()) { - return false; - } - for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) { - if (*T1 != *T2) { - return false; - } - } - return A.ReturnType != B.ReturnType; - } - } - } + // bool Type::operator==(const Type& Other) const noexcept { + // switch (Kind) { + // case TypeKind::Var: + // if (Other.Kind != TypeKind::Var) { + // return false; + // } + // return static_cast(this)->Id == static_cast(Other).Id; + // case TypeKind::Tuple: + // { + // if (Other.Kind != TypeKind::Tuple) { + // return false; + // } + // auto A = static_cast(*this); + // auto B = static_cast(Other); + // if (A.ElementTypes.size() != B.ElementTypes.size()) { + // return false; + // } + // for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) { + // if (*T1 != *T2) { + // return false; + // } + // } + // return true; + // } + // case TypeKind::TupleIndex: + // { + // if (Other.Kind != TypeKind::TupleIndex) { + // return false; + // } + // auto A = static_cast(*this); + // auto B = static_cast(Other); + // return A.I 
== B.I && *A.Ty == *B.Ty; + // } + // case TypeKind::Con: + // { + // if (Other.Kind != TypeKind::Con) { + // return false; + // } + // auto A = static_cast(*this); + // auto B = static_cast(Other); + // if (A.Id != B.Id) { + // return false; + // } + // if (A.Args.size() != B.Args.size()) { + // return false; + // } + // for (auto [T1, T2]: zen::zip(A.Args, B.Args)) { + // if (*T1 != *T2) { + // return false; + // } + // } + // return true; + // } + // case TypeKind::Arrow: + // { + // if (Other.Kind != TypeKind::Arrow) { + // return false; + // } + // auto A = static_cast(*this); + // auto B = static_cast(Other); + // /* ArrowCursor C1 { &A }; */ + // /* ArrowCursor C2 { &B }; */ + // /* for (;;) { */ + // /* auto T1 = C1.next(); */ + // /* auto T2 = C2.next(); */ + // /* if (T1 == nullptr && T2 == nullptr) { */ + // /* break; */ + // /* } */ + // /* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */ + // /* return false; */ + // /* } */ + // /* } */ + // if (A.ParamTypes.size() != B.ParamTypes.size()) { + // return false; + // } + // for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) { + // if (*T1 != *T2) { + // return false; + // } + // } + // return A.ReturnType != B.ReturnType; + // } + // case TypeKind::Absent: + // if (Other.Kind != TypeKind::Absent) { + // return false; + // } + // return true; + // case TypeKind::Nil: + // if (Other.Kind != TypeKind::Nil) { + // return false; + // } + // return true; + // case TypeKind::Present: + // { + // if (Other.Kind != TypeKind::Present) { + // return false; + // } + // auto A = static_cast(*this); + // auto B = static_cast(Other); + // return *A.Ty == *B.Ty; + // } + // case TypeKind::Field: + // { + // if (Other.Kind != TypeKind::Field) { + // return false; + // } + // auto A = static_cast(*this); + // auto B = static_cast(Other); + // return *A.Ty == *B.Ty && *A.RestTy == *B.RestTy; + // } + // } + // } TypeIterator Type::begin() { return TypeIterator { this, getStartIndex() }; @@ -333,14 +444,6 @@ 
namespace bolt { TypeIndex Type::getStartIndex() { switch (Kind) { - case TypeKind::Con: - { - auto Con = static_cast(this); - if (Con->Args.empty()) { - return TypeIndex(TypeIndexKind::End); - } - return TypeIndex::forConArg(0); - } case TypeKind::Arrow: { auto Arrow = static_cast(this); @@ -357,6 +460,8 @@ namespace bolt { } return TypeIndex::forTupleElement(0); } + case TypeKind::Field: + return TypeIndex::forFieldType(); default: return TypeIndex(TypeIndexKind::End); } @@ -366,4 +471,25 @@ namespace bolt { return TypeIndex(TypeIndexKind::End); } + + inline Type* TVar::find() { + TVar* Curr = this; + for (;;) { + auto Keep = Curr->Parent; + if (Keep->getKind() != TypeKind::Var || Keep == Curr) { + return Keep; + } + auto TV = static_cast(Keep); + Curr->Parent = TV->Parent; + Curr = TV; + } + } + + void TVar::set(Type* Ty) { + auto Root = find(); + // It is not possible to set a solution twice. + ZEN_ASSERT(Root->getKind() == TypeKind::Var); + static_cast(Root)->Parent = Ty; + } + }