From 5ba2aafc68acb56a330aaa169d347b3e37dda593 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Fri, 21 Jun 2024 00:18:44 +0200 Subject: [PATCH] Switch to bidirectional type-checker and many more improvements --- CMakeLists.txt | 23 +- include/bolt/CST.hpp | 132 +- include/bolt/CSTVisitor.hpp | 23 +- include/bolt/Checker.hpp | 359 ++--- include/bolt/Common.hpp | 2 + include/bolt/ConsolePrinter.hpp | 9 +- include/bolt/Diagnostics.hpp | 168 +-- include/bolt/Either.hpp | 124 ++ include/bolt/Program.hpp | 6 +- include/bolt/Type.hpp | 824 +++-------- src/CST.cc | 13 +- src/Checker.cc | 2315 ++++++------------------------- src/ConsolePrinter.cc | 371 +---- src/Diagnostics.cc | 3 +- src/Evaluator.cc | 22 +- src/LLVMCodeGen.cc | 5 +- src/LLVMCodeGen.hpp | 3 +- src/Parser.cc | 2 + src/Program.cc | 3 + src/Scanner.cc | 2 - src/Type.cc | 102 ++ src/Types.cc | 336 ----- src/main.cc | 7 +- x.py | 7 +- 24 files changed, 1171 insertions(+), 3690 deletions(-) create mode 100644 include/bolt/Either.hpp create mode 100644 src/Program.cc create mode 100644 src/Type.cc delete mode 100644 src/Types.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index b2795fff4..22fdb8241 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,11 +1,12 @@ -cmake_minimum_required(VERSION 3.10) +cmake_minimum_required(VERSION 3.20) project(Bolt C CXX) set(CMAKE_CXX_STANDARD 20) add_subdirectory(deps/zen EXCLUDE_FROM_ALL) +add_subdirectory(deps/llvm-project/llvm EXCLUDE_FROM_ALL) set(ICU_DIR "${CMAKE_CURRENT_SOURCE_DIR}/build/icu/install") set(ICU_CFLAGS "-DUNISTR_FROM_CHAR_EXPLICIT=explicit -DUNISTR_FROM_STRING_EXPLICIT=explicit -DU_NO_DEFAULT_INCLUDE_UTF_HEADERS=1 -DU_HIDE_OBSOLETE_UTF_OLD_H=1") @@ -17,7 +18,7 @@ if (CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo" OR CMAKE_BUILD_TYPE STREQUAL "Deb set(BOLT_DEBUG ON) endif() -find_package(LLVM 18.1.0 REQUIRED) +#find_package(LLVM 19.0 REQUIRED) add_library( BoltCore @@ -27,10 +28,11 @@ add_library( src/ConsolePrinter.cc src/Scanner.cc src/Parser.cc - src/Types.cc + src/Type.cc src/Checker.cc src/Evaluator.cc src/Scope.cc + src/Program.cc ) target_link_directories( BoltCore @@ -41,6 +43,7 @@ target_compile_options( BoltCore PUBLIC -Werror + -fno-exceptions ${ICU_CFLAGS} ) @@ -68,16 +71,18 @@ add_library( BoltLLVM src/LLVMCodeGen.cc ) -llvm_map_components_to_libnames(llvm_libs support core irreader) -target_include_directories(BoltLLVM PRIVATE ${LLVM_INCLUDE_DIRS}) -separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) -target_compile_definitions(BoltLLVM PRIVATE ${LLVM_DEFINITIONS_LIST}) - target_link_libraries( BoltLLVM PUBLIC BoltCore - ${llvm_libs} + LLVMCore + LLVMTarget +) +target_include_directories( + BoltLLVM + PUBLIC + deps/llvm-project/llvm/include # FIXME this is a hack + ${CMAKE_BINARY_DIR}/deps/llvm-project/llvm/include # FIXME this is a hack ) add_executable( diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 4a4a15f43..6106b86f6 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -1,24 +1,23 @@ #ifndef BOLT_CST_HPP #define BOLT_CST_HPP +#include #include -#include #include #include #include +#include -#include "bolt/Common.hpp" #include "zen/config.hpp" +#include "bolt/Common.hpp" #include "bolt/Integer.hpp" #include "bolt/String.hpp" #include "bolt/ByteString.hpp" +#include "bolt/Type.hpp" namespace bolt { -class Type; -class InferContext; - class Token; class SourceFile; class Scope; @@ -1265,6 +1264,8 @@ public: return Ty; } + static bool classof(Node* N); + }; class TypeExpression : public TypedNode, AnnotationContainer { @@ -1273,6 +1274,19 @@ protected: inline TypeExpression(NodeKind Kind, std::vector Annotations = {}): TypedNode(Kind), AnnotationContainer(Annotations) {} +public: + + static bool classof(Node* N) { + return N->getKind() == NodeKind::ReferenceTypeExpression + || N->getKind() == NodeKind::AppTypeExpression + || N->getKind() == NodeKind::NestedTypeExpression + || N->getKind() == NodeKind::ArrowTypeExpression + || N->getKind() == NodeKind::VarTypeExpression + || N->getKind() == NodeKind::TupleTypeExpression + || N->getKind() == NodeKind::RecordTypeExpression + || N->getKind() == NodeKind::QualifiedTypeExpression; + } + }; class ConstraintExpression : public Node { @@ -1740,6 +1754,21 @@ protected: inline Expression(NodeKind Kind, std::vector Annotations = {}): TypedNode(Kind), AnnotationContainer(Annotations) {} +public: + + static bool classof(Node* N) { + return N->getKind() == NodeKind::ReferenceExpression + || N->getKind() == NodeKind::NestedExpression + || N->getKind() == NodeKind::CallExpression + || N->getKind() == NodeKind::TupleExpression + || N->getKind() == NodeKind::InfixExpression + || N->getKind() == NodeKind::RecordExpression + || N->getKind() == NodeKind::MatchExpression + || N->getKind() == NodeKind::MemberExpression + || N->getKind() == NodeKind::LiteralExpression + || N->getKind() == NodeKind::PrefixExpression; + } + }; class ReferenceExpression : public Expression { @@ -1780,8 +1809,6 @@ class MatchCase : public Node { public: - InferContext* Ctx; - class Pattern* Pattern; class RArrowAlt* RArrowAlt; class Expression* Expression; @@ -2117,6 +2144,14 @@ protected: inline Statement(NodeKind Type, std::vector Annotations = {}): Node(Type), AnnotationContainer(Annotations) {} +public: + + static bool classof(Node* N) { + return N->getKind() == NodeKind::ExpressionStatement + || N->getKind() == NodeKind::ReturnStatement + || N->getKind() == NodeKind::IfStatement; + } + }; class ExpressionStatement : public Statement { @@ -2192,14 +2227,14 @@ class ReturnStatement : public Statement { public: class ReturnKeyword* ReturnKeyword; - class Expression* Expression; + Expression* E; ReturnStatement( class ReturnKeyword* ReturnKeyword, class Expression* Expression ): Statement(NodeKind::ReturnStatement), ReturnKeyword(ReturnKeyword), - Expression(Expression) {} + E(Expression) {} ReturnStatement( std::vector Annotations, @@ -2207,11 +2242,19 @@ public: class Expression* Expression ): Statement(NodeKind::ReturnStatement, Annotations), ReturnKeyword(ReturnKeyword), - Expression(Expression) {} + E(Expression) {} Token* getFirstToken() const override; Token* getLastToken() const override; + bool hasExpression() const { + return E; + } + + Expression* getExpression() { + return E; + } + }; class TypeAssert : public Node { @@ -2297,7 +2340,44 @@ public: }; -class FunctionDeclaration : public TypedNode, public AnnotationContainer { +class Declaration : public TypedNode, public AnnotationContainer { + + std::optional Scm; + +protected: + + + inline Declaration(NodeKind Kind, std::vector Annotations = {}): + TypedNode(Kind), AnnotationContainer(Annotations) {} + +public: + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::VariantDeclaration + || N->getKind() == NodeKind::RecordDeclaration + || N->getKind() == NodeKind::VariantDeclaration + || N->getKind() == NodeKind::PrefixFunctionDeclaration + || N->getKind() == NodeKind::InfixFunctionDeclaration + || N->getKind() == NodeKind::SuffixFunctionDeclaration + || N->getKind() == NodeKind::NamedFunctionDeclaration; + } + + const TypeScheme& getScheme() const { + ZEN_ASSERT(Scm.has_value()); + return *Scm; + } + + bool hasScheme() const { + return Scm.has_value(); + } + + void setScheme(TypeScheme NewScm) { + Scm = NewScm; + } + +}; + +class FunctionDeclaration : public Declaration { Scope* TheScope = nullptr; @@ -2305,10 +2385,9 @@ public: bool IsCycleActive = false; bool Visited = false; - InferContext* Ctx; FunctionDeclaration(NodeKind Kind, std::vector Annotations = {}): - TypedNode(Kind), AnnotationContainer(Annotations) {} + Declaration(Kind, Annotations) {} virtual bool isPublic() const = 0; @@ -2604,7 +2683,7 @@ public: }; -class VariableDeclaration : public TypedNode, public AnnotationContainer { +class VariableDeclaration : public Declaration { public: class PubKeyword* PubKeyword; @@ -2625,8 +2704,7 @@ public: class Pattern* Pattern, class TypeAssert* TypeAssert, LetBody* Body - ): TypedNode(NodeKind::VariableDeclaration), - AnnotationContainer(Annotations), + ): Declaration(NodeKind::VariableDeclaration, Annotations), PubKeyword(PubKeyword), ForeignKeyword(ForeignKeyword), LetKeyword(LetKeyword), @@ -2651,6 +2729,15 @@ public: return N->getKind() == NodeKind::VariableDeclaration; } + bool hasExpression() const { + return Body; + } + + Expression* getExpression() { + ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody); + return static_cast(Body)->Expression; + } + }; class InstanceDeclaration : public Node { @@ -2739,11 +2826,9 @@ public: }; -class RecordDeclaration : public Node { +class RecordDeclaration : public Declaration { public: - InferContext* Ctx; - class PubKeyword* PubKeyword; class StructKeyword* StructKeyword; IdentifierAlt* Name; @@ -2758,7 +2843,7 @@ public: std::vector Vars, class BlockStart* BlockStart, std::vector Fields - ): Node(NodeKind::RecordDeclaration), + ): Declaration(NodeKind::RecordDeclaration), PubKeyword(PubKeyword), StructKeyword(StructKeyword), Name(Name), @@ -2818,11 +2903,9 @@ public: }; -class VariantDeclaration : public Node { +class VariantDeclaration : public Declaration { public: - InferContext* Ctx; - class PubKeyword* PubKeyword; class EnumKeyword* EnumKeyword; class IdentifierAlt* Name; @@ -2837,7 +2920,7 @@ public: std::vector TVs, class BlockStart* BlockStart, std::vector Members - ): Node(NodeKind::VariantDeclaration), + ): Declaration(NodeKind::VariantDeclaration), PubKeyword(PubKeyword), EnumKeyword(EnumKeyword), Name(Name), @@ -2857,7 +2940,6 @@ class SourceFile : public Node { public: TextFile File; - InferContext* Ctx; std::vector Elements; diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 3d3fd4ae1..a359a6786 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -1,9 +1,6 @@ #pragma once -#include "CST.hpp" -#include "zen/config.hpp" - #include "bolt/CST.hpp" namespace bolt { @@ -18,6 +15,10 @@ public: case NodeKind::name: \ return static_cast(this)->visit ## name(static_cast(N)); +#define BOLT_VISIT(node) static_cast(this)->visit(node) +#define BOLT_VISIT_SYMBOL(node) static_cast(this)->dispatchSymbol(node) +#define BOLT_VISIT_OPERATOR(node) static_cast(this)->dispatchOperator(node) + switch (N->getKind()) { BOLT_GEN_CASE(VBar) BOLT_GEN_CASE(Equals) @@ -123,13 +124,13 @@ public: void dispatchSymbol(const Symbol& S) { switch (S.getKind()) { case NodeKind::Identifier: - visit(S.asIdentifier()); + BOLT_VISIT(S.asIdentifier()); break; case NodeKind::IdentifierAlt: - visit(S.asIdentifierAlt()); + BOLT_VISIT(S.asIdentifierAlt()); break; case NodeKind::WrappedOperator: - visit(S.asWrappedOperator()); + BOLT_VISIT(S.asWrappedOperator()); break; default: ZEN_UNREACHABLE @@ -139,10 +140,10 @@ public: void dispatchOperator(const Operator& O) { switch (O.getKind()) { case NodeKind::VBar: - visit(O.asVBar()); + BOLT_VISIT(O.asVBar()); break; case NodeKind::CustomOperator: - visit(O.asCustomOperator()); + BOLT_VISIT(O.asCustomOperator()); break; default: ZEN_UNREACHABLE @@ -698,10 +699,6 @@ public: } } -#define BOLT_VISIT(node) static_cast(this)->visit(node) -#define BOLT_VISIT_SYMBOL(node) static_cast(this)->dispatchSymbol(node) -#define BOLT_VISIT_OPERATOR(node) static_cast(this)->dispatchOperator(node) - void visitEachChild(VBar* N) { } @@ -1152,7 +1149,7 @@ public: BOLT_VISIT(A); } BOLT_VISIT(N->ReturnKeyword); - BOLT_VISIT(N->Expression); + BOLT_VISIT(N->E); } void visitEachChild(IfStatement* N) { diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 9f10a22a2..dc5a3913f 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -3,340 +3,145 @@ #include #include -#include -#include +#include #include "zen/tuple_hash.hpp" #include "bolt/ByteString.hpp" -#include "bolt/Common.hpp" #include "bolt/CST.hpp" +#include "bolt/DiagnosticEngine.hpp" #include "bolt/Type.hpp" -#include "bolt/Support/Graph.hpp" namespace bolt { -std::string describe(const Type* Ty); // For debugging only - -enum class SymKind { - Type, - Var, +enum class ConstraintKind { + TypesEqual, }; -class DiagnosticEngine; +class Constraint { -class Constraint; - -using ConstraintSet = std::vector; - -enum class SchemeKind : unsigned char { - Forall, -}; - -class Scheme { - - const SchemeKind Kind; + ConstraintKind Kind; protected: - inline Scheme(SchemeKind Kind): + Constraint(ConstraintKind Kind): Kind(Kind) {} public: - inline SchemeKind getKind() const noexcept { + inline ConstraintKind getKind() const { return Kind; } - virtual ~Scheme() {} - }; -class Forall : public Scheme { +class CTypesEqual : public Constraint { + + Type* A; + Type* B; + Node* Origin; + public: - TVSet* TVs; - ConstraintSet* Constraints; - class Type* Type; + CTypesEqual(Type* A, Type* B, Node* Origin): + Constraint(ConstraintKind::TypesEqual), A(A), B(B), Origin(Origin) {} - inline Forall(class Type* Type): - Scheme(SchemeKind::Forall), TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {} + Type* getLeft() const { + return A; + } - inline Forall( - TVSet* TVs, - ConstraintSet* Constraints, - class Type* Type - ): Scheme(SchemeKind::Forall), - TVs(TVs), - Constraints(Constraints), - Type(Type) {} + Type* getRight() const { + return B; + } - static bool classof(const Scheme* Scm) { - return Scm->getKind() == SchemeKind::Forall; + Node* getOrigin() const { + return Origin; } }; class TypeEnv { - std::unordered_map, Scheme*> Mapping; + TypeEnv* Parent; + + std::unordered_map, TypeScheme*> Mapping; public: - Scheme* lookup(ByteString Name, SymKind Kind) { - auto Key = std::make_tuple(Name, Kind); - auto Match = Mapping.find(Key); - if (Match == Mapping.end()) { - return nullptr; - } - return Match->second; - } + TypeEnv(TypeEnv* Parent = nullptr): + Parent(Parent) {} - void add(ByteString Name, Scheme* Scm, SymKind Kind) { - auto Key = std::make_tuple(Name, Kind); - ZEN_ASSERT(!Mapping.count(Key)) - // auto F = static_cast(Scm); - // std::cerr << Name << " : forall "; - // for (auto TV: *F->TVs) { - // std::cerr << describe(TV) << " "; - // } - // std::cerr << ". " << describe(F->Type) << "\n"; - Mapping.emplace(Key, Scm); - } + void add(ByteString Name, Type* Ty, SymbolKind Kind); + void add(ByteString Name, TypeScheme* Ty, SymbolKind Kind); + + bool hasVar(TVar* TV) const; + + TypeScheme* lookup(ByteString Name, SymbolKind Kind); }; - -enum class ConstraintKind { - Equal, - Field, - Many, - Empty, -}; - -class Constraint { - - const ConstraintKind Kind; - -public: - - inline Constraint(ConstraintKind Kind): - Kind(Kind) {} - - inline ConstraintKind getKind() const noexcept { - return Kind; - } - - Constraint* substitute(const TVSub& Sub); - - virtual ~Constraint() {} - -}; - -class CEqual : public Constraint { -public: - - Type* Left; - Type* Right; - Node* Source; - - inline CEqual(Type* Left, Type* Right, Node* Source = nullptr): - Constraint(ConstraintKind::Equal), Left(Left), Right(Right), Source(Source) {} - -}; - -class CField : public Constraint { -public: - - Type* TupleTy; - size_t I; - Type* FieldTy; - Node* Source; - - inline CField(Type* TupleTy, size_t I, Type* FieldTy, Node* Source = nullptr): - Constraint(ConstraintKind::Field), TupleTy(TupleTy), I(I), FieldTy(FieldTy), Source(Source) {} - -}; - -class CMany : public Constraint { -public: - - ConstraintSet& Elements; - - inline CMany(ConstraintSet& Elements): - Constraint(ConstraintKind::Many), Elements(Elements) {} - -}; - -class CEmpty : public Constraint { -public: - - inline CEmpty(): - Constraint(ConstraintKind::Empty) {} - -}; - -using InferContextFlagsMask = unsigned; - -class InferContext { -public: - - /** - * A heap-allocated list of type variables that eventually will become part of a Forall scheme. - */ - TVSet* TVs; - - /** - * A heap-allocated list of constraints that eventually will become part of a Forall scheme. - */ - ConstraintSet* Constraints; - - TypeEnv Env; - - Type* ReturnType = nullptr; - - InferContext* Parent = nullptr; - -}; +using ConstraintSet = std::vector; class Checker { - friend class Unifier; - friend class UnificationFrame; - - const LanguageConfig& Config; DiagnosticEngine& DE; - size_t NextConTypeId = 0; - size_t NextTypeVarId = 0; - - Type* BoolType; - Type* ListType; Type* IntType; + Type* BoolType; Type* StringType; - Type* UnitType; - - Graph RefGraph; - - std::unordered_map> InstanceMap; - - /// Inference context management - - InferContext* ActiveContext; - - InferContext& getContext(); - void setContext(InferContext* Ctx); - void popContext(); - - void makeEqual(Type* A, Type* B, Node* Source); - - void addConstraint(Constraint* Constraint); - - /** - * Get the return type for the current context. If none could be found, the - * program will abort. - */ - Type* getReturnType(); - - /// Type inference - - void forwardDeclare(Node* Node); - void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); - - Type* inferExpression(Expression* Expression); - Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true); - Type* inferLiteral(Literal* Lit); - Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); - - void infer(Node* node); - void inferFunctionDeclaration(FunctionDeclaration* N); - void inferConstraintExpression(ConstraintExpression* C); - - /// Factory methods - - Type* createConType(ByteString Name); - Type* createTypeVar(); - Type* createRigidVar(ByteString Name); - InferContext* createInferContext( - InferContext* Parent = nullptr, - TVSet* TVs = new TVSet, - ConstraintSet* Constraints = new ConstraintSet - ); - - /// Environment manipulation - - Scheme* lookup(ByteString Name, SymKind Kind); - - /** - * Looks up a type/variable and ensures that it is a monomorphic type. - * - * This method is mainly syntactic sugar to make it clear in the code when a - * monomorphic type is expected. - * - * Note that if the type is not monomorphic the program will abort with a - * stack trace. It wil **not** print a user-friendly error message. - * - * \returns If the type/variable could not be found `nullptr` is returned. - * Otherwise, a [Type] is returned. - */ - Type* lookupMono(ByteString Name, SymKind Kind); - - void addBinding(ByteString Name, Scheme* Scm, SymKind Kind); - - /// Constraint solving - - /** - * The queue that is used during solving to store any unsolved constraints. - */ - std::deque Queue; - - /** - * Unify two types, using `Source` as source location. - * - * \returns Whether a type variable was assigned a type or not. - */ - bool unify(Type* Left, Type* Right, Node* Source); - - void solve(Constraint* Constraint); - - /// Helpers - - void populate(SourceFile* SF); - - /** - * Verifies that type class signatures on type asserts in let-declarations - * correctly declare the right type classes. - */ - void checkTypeclassSigs(Node* N); - - Type* instantiate(Scheme* S, Node* Source); - - void initialize(Node* N); public: - Checker(const LanguageConfig& Config, DiagnosticEngine& DE); - - /** - * \internal - */ - Type* solveType(Type* Ty); - - void check(SourceFile* SF); - - inline Type* getBoolType() const { - return BoolType; + Checker(DiagnosticEngine& DE): + DE(DE) { + IntType = new TCon("Int"); + BoolType = new TCon("Bool"); + StringType = new TCon("String"); } - inline Type* getStringType() const { - return StringType; - } - - inline Type* getIntType() const { + Type* getIntType() const { return IntType; } - Type* getType(TypedNode* Node); + Type* getBoolType() const { + return BoolType; + } + + Type* getStringType() const { + return StringType; + } + + TVar* createTVar() { + return new TVar(); + } + + Type* instantiate(TypeScheme* Scm); + + void visitPattern(Pattern* P, Type* Ty, TypeEnv& Out); + + ConstraintSet inferSourceFile(TypeEnv& Env, SourceFile* SF); + + ConstraintSet inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaration* D); + + ConstraintSet inferVariableDeclaration(TypeEnv& Env, VariableDeclaration* Decl); + + ConstraintSet inferMany(TypeEnv& Env, std::vector& N, Type* RetTy); + + ConstraintSet inferElement(TypeEnv& Env, Node* N, Type* RetTy); + + std::tuple inferTypeExpr(TypeEnv& Env, TypeExpression* TE); + + std::tuple inferExpr(TypeEnv& Env, Expression* Expr, Type* RetTy); + + ConstraintSet checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, Type* RetTy); + + void solve(const std::vector& Constraints); + + void unifyTypeType(Type* A, Type* B, Node* Source); + + void run(SourceFile* SF); + + Type* getTypeOfNode(Node* N); }; diff --git a/include/bolt/Common.hpp b/include/bolt/Common.hpp index cecd29e2c..e6ee8c33a 100644 --- a/include/bolt/Common.hpp +++ b/include/bolt/Common.hpp @@ -1,6 +1,8 @@ #pragma once +#include + #include "zen/config.hpp" namespace bolt { diff --git a/include/bolt/ConsolePrinter.hpp b/include/bolt/ConsolePrinter.hpp index e98bd6cef..d1fea0e94 100644 --- a/include/bolt/ConsolePrinter.hpp +++ b/include/bolt/ConsolePrinter.hpp @@ -5,13 +5,10 @@ #include "bolt/ByteString.hpp" #include "bolt/CST.hpp" -#include "bolt/Type.hpp" namespace bolt { class Node; -class Type; -class TypeclassSignature; class Diagnostic; enum class Color { @@ -160,12 +157,8 @@ class ConsolePrinter { void writePrefix(const Diagnostic& D); void writeBinding(const ByteString& Name); - void writeType(std::size_t I); - void writeType(const Type* Ty, const TypePath& Underline); - void writeType(const Type* Ty); void writeLoc(const TextFile& File, const TextLoc& Loc); - void writeTypeclassName(const ByteString& Name); - void writeTypeclassSignature(const TypeclassSignature& Sig); + void writeType(Type* Ty); void write(const std::string_view& S); void write(std::size_t N); diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index 56b1dd5fa..ddefc9d48 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -1,6 +1,7 @@ #pragma once +#include #include #include "bolt/ByteString.hpp" @@ -12,15 +13,15 @@ namespace bolt { enum class DiagnosticKind : unsigned char { BindingNotFound, - FieldNotFound, - InstanceNotFound, - InvalidTypeToTypeclass, - NotATuple, - TupleIndexOutOfRange, - TypeclassMissing, + // FieldNotFound, + // InstanceNotFound, + // InvalidTypeToTypeclass, + // NotATuple, + // TupleIndexOutOfRange, + // TypeclassMissing, UnexpectedString, UnexpectedToken, - UnificationError, + TypeMismatchError, }; class Diagnostic { @@ -33,7 +34,7 @@ protected: public: - inline DiagnosticKind getKind() const noexcept { + inline DiagnosticKind getKind() const { return Kind; } @@ -41,7 +42,7 @@ public: return nullptr; } - virtual unsigned getCode() const noexcept = 0; + virtual unsigned getCode() const = 0; virtual ~Diagnostic() {} @@ -57,7 +58,7 @@ public: inline UnexpectedStringDiagnostic(TextFile& File, TextLoc Location, String Actual): Diagnostic(DiagnosticKind::UnexpectedString), File(File), Location(Location), Actual(Actual) {} - unsigned getCode() const noexcept override { + unsigned getCode() const override { return 1001; } @@ -73,7 +74,7 @@ public: inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector Expected): Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {} - unsigned getCode() const noexcept override { + unsigned getCode() const override { return 1101; } @@ -92,153 +93,28 @@ public: return Initiator; } - unsigned getCode() const noexcept override { + unsigned getCode() const override { return 2005; } }; -class UnificationErrorDiagnostic : public Diagnostic { +class TypeMismatchError : public Diagnostic { public: - Type* OrigLeft; - Type* OrigRight; - TypePath LeftPath; - TypePath RightPath; - Node* Source; + Type* Left; + Type* Right; + Node* N; - inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source): - Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} - - inline Type* getLeft() const { - return OrigLeft->resolve(LeftPath); - } - - inline Type* getRight() const { - return OrigRight->resolve(RightPath); - } + inline TypeMismatchError(Type* Left, Type* Right, Node* N): + Diagnostic(DiagnosticKind::TypeMismatchError), Left(Left), Right(Right), N(N) {} inline Node* getNode() const override { - return Source; + return N; } - unsigned getCode() const noexcept override { - return 2010; - } - -}; - -class TypeclassMissingDiagnostic : public Diagnostic { -public: - - TypeclassSignature Sig; - Node* Decl; - - inline TypeclassMissingDiagnostic(TypeclassSignature Sig, Node* Decl): - Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} - - inline Node* getNode() const override { - return Decl; - } - - unsigned getCode() const noexcept override { - return 2201; - } - -}; - -class InstanceNotFoundDiagnostic : public Diagnostic { -public: - - ByteString TypeclassName; - Type* Ty; - Node* Source; - - inline InstanceNotFoundDiagnostic(ByteString TypeclassName, Type* Ty, Node* Source): - Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {} - - inline Node* getNode() const override { - return Source; - } - - unsigned getCode() const noexcept override { - return 2251; - } - -}; - -class TupleIndexOutOfRangeDiagnostic : public Diagnostic { -public: - - Type* Tuple; - std::size_t I; - Node* Source; - - inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I, Node* Source): - Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I), Source(Source) {} - - inline Node * getNode() const override { - return Source; - } - - unsigned getCode() const noexcept override { - return 2015; - } - -}; - -class InvalidTypeToTypeclassDiagnostic : public Diagnostic { -public: - - Type* Actual; - std::vector Classes; - Node* Source; - - inline InvalidTypeToTypeclassDiagnostic(Type* Actual, std::vector Classes, Node* Source): - Diagnostic(DiagnosticKind::InvalidTypeToTypeclass), Actual(Actual), Classes(Classes), Source(Source) {} - - inline Node* getNode() const override { - return Source; - } - - unsigned getCode() const noexcept override { - return 2060; - } - -}; - -class FieldNotFoundDiagnostic : public Diagnostic { -public: - - ByteString Name; - Type* Ty; - TypePath Path; - Node* Source; - - inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source): - Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} - - unsigned getCode() const noexcept override { - return 2017; - } - -}; - -class NotATupleDiagnostic : public Diagnostic { -public: - - Type* Ty; - Node* Source; - - inline NotATupleDiagnostic(Type* Ty, Node* Source): - Diagnostic(DiagnosticKind::NotATuple), Ty(Ty), Source(Source) {} - - inline Node * getNode() const override { - return Source; - } - - unsigned getCode() const noexcept override { - return 2016; + unsigned getCode() const override { + return 3001; } }; diff --git a/include/bolt/Either.hpp b/include/bolt/Either.hpp new file mode 100644 index 000000000..bd4e375a1 --- /dev/null +++ b/include/bolt/Either.hpp @@ -0,0 +1,124 @@ + +#pragma once + +#include +#include +#include +#include + +#include "bolt/Common.hpp" + +namespace bolt { + +template +concept ErrorLike = requires (T a) { + { message(a) } -> std::convertible_to; +}; + +template +struct Left { + T value; +}; + +template +struct Right { + T value; +}; + +template +class Either { + + bool _is_left; + + union { + L _left; + R _right; + }; + +public: + + template + Either(const Left& left): + _is_left(true), _left(left.value) {} + + template + Either(const Right& right): + _is_left(false), _right(right.value) {} + + template + Either(Left&& left): + _is_left(true), _left(std::move(left.value)) {} + + template + Either(Right&& right): + _is_left(false), _right(std::move(right.value)) {} + + Either(const Either& other): + _is_left(_is_left) { + if (other._is_left) { + new (&_left)L(other._left); + } else { + new (&_right)L(other._right); + } + } + + Either(Either&& other): + _is_left(std::move(other._is_left)) { + if (_is_left) { + new (&_left)L(std::move(other._left)); + } else { + new (&_right)L(std::move(other._right)); + } + } + + bool is_left() const { + return _is_left; + } + + auto left() const { + return _left; + } + + auto right() const { + return _right; + } + + R&& unwrap() requires ErrorLike { + if (_is_left) { + auto desc = message(_left); + ZEN_PANIC("trying to unwrap a result containing an error: %s", desc.c_str()); + } + return std::move(_right); + } + + ~Either() { + if (_is_left) { + _left.~L(); + } else { + _right.~R(); + } + } + +}; + +// template +// auto left(const L& value) { +// return Left { value }; +// } + +template +auto left(L&& value) { + return Left { std::move(value) }; +} + +// template +// auto right(const R& value) { +// return Right { value }; +// } + +template +auto right(R&& value) { + return Right { std::move(value) }; +} + +} diff --git a/include/bolt/Program.hpp b/include/bolt/Program.hpp index f1a29fcd6..549453729 100644 --- a/include/bolt/Program.hpp +++ b/include/bolt/Program.hpp @@ -5,6 +5,8 @@ #include #include +#include "zen/range.hpp" + #include "bolt/Common.hpp" #include "bolt/Checker.hpp" #include "bolt/DiagnosticEngine.hpp" @@ -47,12 +49,12 @@ public: if (Match != TCs.end()) { return Match->second; } - return TCs.emplace(SF, Checker { Config, DE }).first->second; + return TCs.emplace(SF, Checker { DE }).first->second; } void check() { for (auto SF: getSourceFiles()) { - getTypeChecker(SF).check(SF); + getTypeChecker(SF).run(SF); } } diff --git a/include/bolt/Type.hpp b/include/bolt/Type.hpp index e51ac3eb4..8c6152b35 100644 --- a/include/bolt/Type.hpp +++ b/include/bolt/Type.hpp @@ -1,61 +1,32 @@ #pragma once -#include -#include -#include -#include -#include +#include +#include #include +#include #include "zen/config.hpp" -#include "bolt/CST.hpp" #include "bolt/ByteString.hpp" namespace bolt { -class Type; -class TCon; - -using TypeclassId = ByteString; - -using TypeclassContext = std::unordered_set; - -struct TypeclassSignature { - - using TypeclassId = ByteString; - TypeclassId Id; - std::vector Params; - - bool operator<(const TypeclassSignature& Other) const; - bool operator==(const TypeclassSignature& Other) const; - -}; - -struct TypeSig { - Type* Orig; - Type* Op; - std::vector Args; -}; - enum class TypeIndexKind { - AppOpType, - AppArgType, - ArrowParamType, - ArrowReturnType, + AppOp, + AppArg, + ArrowLeft, + ArrowRight, TupleElement, - FieldType, - FieldRestType, - PresentType, + FieldElement, + FieldRest, + PresentElement, End, }; class TypeIndex { -protected: friend class Type; - friend class TypeIterator; TypeIndexKind Kind; @@ -71,685 +42,202 @@ protected: public: - bool operator==(const TypeIndex& Other) const noexcept; - - void advance(const Type* Ty); - - static TypeIndex forFieldType() { - return { TypeIndexKind::FieldType }; + static TypeIndex forAppOp() { + return { TypeIndexKind::AppOp }; } - static TypeIndex forFieldRest() { - return { TypeIndexKind::FieldRestType }; + static TypeIndex forAppArg() { + return { TypeIndexKind::AppArg }; } - static TypeIndex forArrowParamType() { - return { TypeIndexKind::ArrowParamType }; + static TypeIndex forArrowLeft() { + return { TypeIndexKind::ArrowLeft }; } - static TypeIndex forArrowReturnType() { - return { TypeIndexKind::ArrowReturnType }; + static TypeIndex forArrowRight() { + return { TypeIndexKind::ArrowRight }; } - static TypeIndex forTupleElement(std::size_t I) { + static TypeIndex forTupleIndex(std::size_t I) { return { TypeIndexKind::TupleElement, I }; } - static TypeIndex forAppOpType() { - return { TypeIndexKind::AppOpType }; - } - - static TypeIndex forAppArgType() { - return { TypeIndexKind::AppArgType }; - } - - static TypeIndex forPresentType() { - return { TypeIndexKind::PresentType }; - } - -}; - -class TypeIterator { - - friend class Type; - - Type* Ty; - TypeIndex Index; - - TypeIterator(Type* Ty, TypeIndex Index): - Ty(Ty), Index(Index) {} - -public: - - TypeIterator& operator++() noexcept { - Index.advance(Ty); - return *this; - } - - bool operator==(const TypeIterator& Other) const noexcept { - return Ty == Other.Ty && Index == Other.Index; - } - - Type* operator*() { - return Ty; - } - - TypeIndex getIndex() const noexcept { - return Index; - } - }; using TypePath = std::vector; -using TVSub = std::unordered_map; -using TVSet = std::unordered_set; - -enum class TypeKind : unsigned char { +enum class TypeKind { Var, Con, + Fun, App, - Arrow, - Tuple, - Field, - Nil, - Absent, - Present, }; -class Type; +class TVar; +class TCon; +class TFun; +class TApp; -struct TCon { - size_t Id; - ByteString DisplayName; +class Type { +protected: - bool operator==(const TCon& Other) const; + TypeKind TK; -}; + Type(TypeKind TK): + TK(TK) {} -struct TApp { - Type* Op; - Type* Arg; +public: - bool operator==(const TApp& Other) const; - -}; - -enum class VarKind { - Rigid, - Unification, -}; - -struct TVar { - VarKind VK; - size_t Id; - TypeclassContext Context; - std::optional Name; - std::optional Provided; - - VarKind getKind() const { - return VK; + virtual Type* find() const { + return const_cast(this); } - bool isUni() const { - return VK == VarKind::Unification; + inline TypeKind getKind() const { + return TK; } - bool isRigid() const { - return VK == VarKind::Rigid; - } - - bool operator==(const TVar& Other) const; - -}; - -struct TArrow { - Type* ParamType; - Type* ReturnType; - - bool operator==(const TArrow& Other) const; - -}; - -struct TTuple { - std::vector ElementTypes; - - bool operator==(const TTuple& Other) const; - -}; - -struct TNil { - bool operator==(const TNil& Other) const; -}; - -struct TField { - ByteString Name; - Type* Ty; - Type* RestTy; - bool operator==(const TField& Other) const; -}; - -struct TAbsent { - bool operator==(const TAbsent& Other) const; -}; - -struct TPresent { - Type* Ty; - bool operator==(const TPresent& Other) const; -}; - -struct Type { - - TypeKind Kind; - - Type* Parent = this; - - union { - TCon Con; - TApp App; - TVar Var; - TArrow Arrow; - TTuple Tuple; - TNil Nil; - TField Field; - TAbsent Absent; - TPresent Present; - }; - - Type(TCon&& Con): - Kind(TypeKind::Con), Con(std::move(Con)) {}; - - Type(TApp&& App): - Kind(TypeKind::App), App(std::move(App)) {}; - - Type(TVar&& Var): - Kind(TypeKind::Var), Var(std::move(Var)) {}; - - Type(TArrow&& Arrow): - Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {}; - - Type(TTuple&& Tuple): - Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {}; - - Type(TNil&& Nil): - Kind(TypeKind::Nil), Nil(std::move(Nil)) {}; - - Type(TField&& Field): - Kind(TypeKind::Field), Field(std::move(Field)) {}; - - Type(TAbsent&& Absent): - Kind(TypeKind::Absent), Absent(std::move(Absent)) {}; - - Type(TPresent&& Present): - Kind(TypeKind::Present), Present(std::move(Present)) {}; - - Type(const Type& Other): Kind(Other.Kind) { - switch (Kind) { - case TypeKind::Con: - new (&Con)TCon(Other.Con); - break; - case TypeKind::App: - new (&App)TApp(Other.App); - break; - case TypeKind::Var: - new (&Var)TVar(Other.Var); - break; - case TypeKind::Arrow: - new (&Arrow)TArrow(Other.Arrow); - break; - case TypeKind::Tuple: - new (&Tuple)TTuple(Other.Tuple); - break; - case TypeKind::Nil: - new (&Nil)TNil(Other.Nil); - break; - case TypeKind::Field: - new (&Field)TField(Other.Field); - break; - case TypeKind::Absent: - new (&Absent)TAbsent(Other.Absent); - break; - case TypeKind::Present: - new (&Present)TPresent(Other.Present); - break; - } - } - - Type(Type&& Other): Kind(std::move(Other.Kind)) { - switch (Kind) { - case TypeKind::Con: - new (&Con)TCon(std::move(Other.Con)); - break; - case TypeKind::App: - new (&App)TApp(std::move(Other.App)); - break; - case TypeKind::Var: - new (&Var)TVar(std::move(Other.Var)); - break; - case TypeKind::Arrow: - new (&Arrow)TArrow(std::move(Other.Arrow)); - break; - case TypeKind::Tuple: - new (&Tuple)TTuple(std::move(Other.Tuple)); - break; - case TypeKind::Nil: - new (&Nil)TNil(std::move(Other.Nil)); - break; - case TypeKind::Field: - new (&Field)TField(std::move(Other.Field)); - break; - case TypeKind::Absent: - new (&Absent)TAbsent(std::move(Other.Absent)); - break; - case TypeKind::Present: - new (&Present)TPresent(std::move(Other.Present)); - break; - } - } - - TypeKind getKind() const { - return Kind; - } - - bool isVarRigid() const { - return Kind == TypeKind::Var - && asVar().getKind() == VarKind::Rigid; - } - - bool isVar() const { - return Kind == TypeKind::Var; - } - - TVar& asVar() { - ZEN_ASSERT(Kind == TypeKind::Var); - return Var; - } - - const TVar& asVar() const { - ZEN_ASSERT(Kind == TypeKind::Var); - return Var; - } - - bool isApp() const { - return Kind == TypeKind::App; - } - - TApp& asApp() { - ZEN_ASSERT(Kind == TypeKind::App); - return App; - } - - const TApp& asApp() const { - ZEN_ASSERT(Kind == TypeKind::App); - return App; - } - - bool isCon() const { - return Kind == TypeKind::Con; - } - - TCon& asCon() { - ZEN_ASSERT(Kind == TypeKind::Con); - return Con; - } - - const TCon& asCon() const { - ZEN_ASSERT(Kind == TypeKind::Con); - return Con; - } - - bool isArrow() const { - return Kind == TypeKind::Arrow; - } - - TArrow& asArrow() { - ZEN_ASSERT(Kind == TypeKind::Arrow); - return Arrow; - } - - const TArrow& asArrow() const { - ZEN_ASSERT(Kind == TypeKind::Arrow); - return Arrow; - } - - bool isTuple() const { - return Kind == TypeKind::Tuple; - } - - TTuple& asTuple() { - ZEN_ASSERT(Kind == TypeKind::Tuple); - return Tuple; - } - - const TTuple& asTuple() const { - ZEN_ASSERT(Kind == TypeKind::Tuple); - return Tuple; - } - - bool isField() const { - return Kind == TypeKind::Field; - } - - TField& asField() { - ZEN_ASSERT(Kind == TypeKind::Field); - return Field; - } - - const TField& asField() const { - ZEN_ASSERT(Kind == TypeKind::Field); - return Field; - } - - bool isAbsent() const { - return Kind == TypeKind::Absent; - } - - TAbsent& asAbsent() { - ZEN_ASSERT(Kind == TypeKind::Absent); - return Absent; - } - const TAbsent& asAbsent() const { - ZEN_ASSERT(Kind == TypeKind::Absent); - return Absent; - } - - bool isPresent() const { - return Kind == TypeKind::Present; - } - - TPresent& asPresent() { - ZEN_ASSERT(Kind == TypeKind::Present); - return Present; - } - const TPresent& asPresent() const { - ZEN_ASSERT(Kind == TypeKind::Present); - return Present; - } - - bool isNil() const { - return Kind == TypeKind::Nil; - } - - TNil& asNil() { - ZEN_ASSERT(Kind == TypeKind::Nil); - return Nil; - } - const TNil& asNil() const { - ZEN_ASSERT(Kind == TypeKind::Nil); - return Nil; - } - - Type* rewrite(std::function Fn, bool Recursive = true); - - Type* resolve(const TypeIndex& Index) const noexcept; - - Type* resolve(const TypePath& Path) noexcept { - Type* Ty = this; - for (auto El: Path) { - Ty = Ty->resolve(El); - } - return Ty; - } - - void set(Type* Ty) { - auto Root = find(); - // It is not possible to set a solution twice. - if (isVar()) { - ZEN_ASSERT(Root->isVar()); - } - Root->Parent = Ty; - } - - Type* find() const { - Type* Curr = const_cast(this); - for (;;) { - auto Keep = Curr->Parent; - if (Keep == Curr) { - return Keep; - } - Curr->Parent = Keep->Parent; - Curr = Keep; - } + bool isVar() const { + return TK == TypeKind::Var; } bool operator==(const Type& Other) const; - void destroy() { - switch (Kind) { - case TypeKind::Con: - App.~TApp(); - break; - case TypeKind::App: - App.~TApp(); - break; - case TypeKind::Var: - Var.~TVar(); - break; - case TypeKind::Arrow: - Arrow.~TArrow(); - break; - case TypeKind::Tuple: - Tuple.~TTuple(); - break; - case TypeKind::Nil: - Nil.~TNil(); - break; - case TypeKind::Field: - Field.~TField(); - break; - case TypeKind::Absent: - Absent.~TAbsent(); - break; - case TypeKind::Present: - Present.~TPresent(); - break; - } - } + std::string toString() const; - Type& operator=(Type& Other) { - destroy(); - Kind = Other.Kind; - switch (Kind) { - case TypeKind::Con: - App = Other.App; - break; - case TypeKind::App: - App = Other.App; - break; - case TypeKind::Var: - Var = Other.Var; - break; - case TypeKind::Arrow: - Arrow = Other.Arrow; - break; - case TypeKind::Tuple: - Tuple = Other.Tuple; - break; - case TypeKind::Nil: - Nil = Other.Nil; - break; - case TypeKind::Field: - Field = Other.Field; - break; - case TypeKind::Absent: - Absent = Other.Absent; - break; - case TypeKind::Present: - Present = Other.Present; - break; - } - return *this; - } + Type* resolve(const TypePath& P); - bool hasTypeVar(Type* TV) const; + TVar* asVar(); + const TVar* asVar() const; - TypeIterator begin(); - TypeIterator end(); + TFun* asFun(); + const TFun* asFun() const; - TypeIndex getStartIndex() const; - TypeIndex getEndIndex() const; - - Type* substitute(const TVSub& Sub); - - void visitEachChild(std::function Proc); - - TVSet getTypeVars(); - - ~Type() { - destroy(); - } - - static Type* buildArrow(std::vector ParamTypes, Type* ReturnType) { - Type* Curr = ReturnType; - for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) { - Curr = new Type(TArrow(*Iter, Curr)); - } - return Curr; - } + TCon* asCon(); + const TCon* asCon() const; }; -template -class TypeVisitorBase { -protected: +class TVar : public Type { - template - using C = std::conditional::type; - - virtual void enterType(C* Ty) {} - virtual void exitType(C* Ty) {} - - // virtual void visitType(C* Ty) { - // visitEachChild(Ty); - // } - - virtual void visitVarType(C& Ty) { - } - - virtual void visitAppType(C& Ty) { - visit(Ty.Op); - visit(Ty.Arg); - } - - virtual void visitPresentType(C& Ty) { - visit(Ty.Ty); - } - - virtual void visitConType(C& Ty) { - } - - virtual void visitArrowType(C& Ty) { - visit(Ty.ParamType); - visit(Ty.ReturnType); - } - - virtual void visitTupleType(C& Ty) { - for (auto ElTy: Ty.ElementTypes) { - visit(ElTy); - } - } - - virtual void visitAbsentType(C& Ty) { - } - - virtual void visitFieldType(C& Ty) { - visit(Ty.Ty); - visit(Ty.RestTy); - } - - virtual void visitNilType(C& Ty) { - } + Type* Parent = this; public: - void visitEachChild(C* Ty) { - switch (Ty->getKind()) { - case TypeKind::Var: - case TypeKind::Absent: - case TypeKind::Nil: - case TypeKind::Con: - break; - case TypeKind::Arrow: - { - auto& Arrow = Ty->asArrow(); - visit(Arrow->ParamType); - visit(Arrow->ReturnType); - break; - } - case TypeKind::Tuple: - { - auto& Tuple = Ty->asTuple(); - for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) { - visit(Tuple->ElementTypes[I]); - } - break; - } - case TypeKind::App: - { - auto& App = Ty->asApp(); - visit(App->Op); - visit(App->Arg); - break; - } - case TypeKind::Field: - { - auto& Field = Ty->asField(); - visit(Field->Ty); - visit(Field->RestTy); - break; - } - case TypeKind::Present: - { - auto& Present = Ty->asPresent(); - visit(Present->Ty); - break; - } - } + TVar(): + Type(TypeKind::Var) {} + + void set(Type* Ty) { + auto Root = find(); + // It is not possible to set a solution twice. + ZEN_ASSERT(Root->isVar()); + static_cast(Root)->Parent = Ty; } - void visit(C* Ty) { - - // Always look at the most solved solution - Ty = Ty->find(); - - enterType(Ty); - switch (Ty->getKind()) { - case TypeKind::Present: - visitPresentType(Ty->asPresent()); - break; - case TypeKind::Absent: - visitAbsentType(Ty->asAbsent()); - break; - case TypeKind::Nil: - visitNilType(Ty->asNil()); - break; - case TypeKind::Field: - visitFieldType(Ty->asField()); - break; - case TypeKind::Con: - visitConType(Ty->asCon()); - break; - case TypeKind::Arrow: - visitArrowType(Ty->asArrow()); - break; - case TypeKind::Var: - visitVarType(Ty->asVar()); - break; - case TypeKind::Tuple: - visitTupleType(Ty->asTuple()); - break; - case TypeKind::App: - visitAppType(Ty->asApp()); - break; + Type* find() const override { + TVar* Curr = const_cast(this); + for (;;) { + auto Keep = Curr->Parent; + if (Keep == Curr || !Keep->isVar()) { + return Keep; + } + auto Keep2 = static_cast(Keep); + Curr->Parent = Keep2->Parent; + Curr = Keep2; } - exitType(Ty); } - virtual ~TypeVisitorBase() {} - }; -using TypeVisitor = TypeVisitorBase; -using ConstTypeVisitor = TypeVisitorBase; +class TCon : public Type { + + ByteString Name; + +public: + + TCon(ByteString Name): + Type(TypeKind::Con), Name(Name) {} + + ByteStringView getName() const { + return Name; + } + +}; + +class TFun : public Type { + + Type* Left; + Type* Right; + +public: + + TFun(Type* Left, Type* Right): + Type(TypeKind::Fun), Left(Left), Right(Right) {} + + Type* getLeft() const { + return Left; + } + + Type* getRight() const { + return Right; + } + +}; + +class TApp : public Type { + + Type* Left; + Type* Right; + +public: + + TApp(Type* Left, Type* Right): + Type(TypeKind::App), Left(Left), Right(Right) {} + + Type* getLeft() const { + return Left; + } + + Type* getRight() const { + return Right; + } + +}; + +struct TypeScheme { + + std::unordered_set Unbound; + Type* Ty; + + Type* getType() const { + return Ty; + } + +}; + +class TypeVisitor { +public: + + void visit(Type* Ty); + + virtual void visitVar(TVar* TV) { + + } + + virtual void visitApp(TApp* App) { + visit(App->getLeft()); + visit(App->getRight()); + } + + virtual void visitCon(TCon* Con) { + + } + + virtual void visitFun(TFun* Fun) { + visit(Fun->getLeft()); + visit(Fun->getRight()); + } + +}; } + diff --git a/src/CST.cc b/src/CST.cc index 0aafe4569..b4b2d5f43 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -1,6 +1,4 @@ -#include "zen/config.hpp" - #include "bolt/CST.hpp" #include "bolt/CSTVisitor.hpp" @@ -512,8 +510,8 @@ Token* ReturnStatement::getFirstToken() const { } Token* ReturnStatement::getLastToken() const { - if (Expression) { - return Expression->getLastToken(); + if (E) { + return E->getLastToken(); } return ReturnKeyword; } @@ -1036,5 +1034,12 @@ SymbolPath ReferenceExpression::getSymbolPath() const { return SymbolPath { ModuleNames, Name.getCanonicalText() }; } +bool TypedNode::classof(Node* N) { + return Expression::classof(N) + || TypeExpression::classof(N) + || FunctionDeclaration::classof(N) + || VariableDeclaration::classof(N); +} + } diff --git a/src/Checker.cc b/src/Checker.cc index 72344ee07..591c540cc 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -1,740 +1,436 @@ -#include -#include -#include - -#include "zen/config.hpp" - -#include "bolt/Type.hpp" #include "bolt/CSTVisitor.hpp" -#include "bolt/DiagnosticEngine.hpp" -#include "bolt/Diagnostics.hpp" +#include "zen/graph.hpp" + +#include "bolt/ByteString.hpp" #include "bolt/CST.hpp" +#include "bolt/Type.hpp" +#include "bolt/Diagnostics.hpp" +#include +#include #include "bolt/Checker.hpp" namespace bolt { -Constraint* Constraint::substitute(const TVSub &Sub) { - switch (Kind) { - case ConstraintKind::Equal: - { - auto Equal = static_cast(this); - return new CEqual(Equal->Left->substitute(Sub), Equal->Right->substitute(Sub), Equal->Source); - } - case ConstraintKind::Many: - { - auto Many = static_cast(this); - auto NewConstraints = new ConstraintSet(); - for (auto Element: Many->Elements) { - NewConstraints->push_back(Element->substitute(Sub)); - } - return new CMany(*NewConstraints); - } - case ConstraintKind::Field: - { - auto Field = static_cast(this); - auto NewTupleTy = Field->TupleTy->substitute(Sub); - auto NewFieldTy = Field->FieldTy->substitute(Sub); - return new CField(NewTupleTy, Field->I, NewFieldTy, Field->Source); - } - case ConstraintKind::Empty: - return this; +static inline void mergeTo(ConstraintSet& Out, const ConstraintSet& Other) { + for (auto C: Other) { + Out.push_back(C); } - ZEN_UNREACHABLE } -Type* Checker::solveType(Type* Ty) { - return Ty->rewrite([this](auto Ty) { return Ty->find(); }, true); -} - -Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): - Config(Config), DE(DE) { - BoolType = createConType("Bool"); - IntType = createConType("Int"); - StringType = createConType("String"); - ListType = createConType("List"); - UnitType = new Type(TTuple({})); - } - -Scheme* Checker::lookup(ByteString Name, SymKind Kind) { - auto Curr = &getContext(); - for (;;) { - auto Match = Curr->Env.lookup(Name, Kind); - if (Match != nullptr) { - return Match; +TypeScheme* TypeEnv::lookup(ByteString Name, SymbolKind Kind) { + auto Curr = this; + do { + auto Match = Curr->Mapping.find(std::make_tuple(Name, Kind)); + if (Match != Curr->Mapping.end()) { + return Match->second; } Curr = Curr->Parent; - if (!Curr) { - break; - } - } + } while (Curr); return nullptr; } -Type* Checker::lookupMono(ByteString Name, SymKind Kind) { - auto Scm = lookup(Name, Kind); - if (Scm == nullptr) { - return nullptr; +void TypeEnv::add(ByteString Name, TypeScheme* Scm, SymbolKind Kind) { + Mapping.emplace(std::make_tuple(Name, Kind), Scm); +} + +void TypeEnv::add(ByteString Name, Type* Ty, SymbolKind Kind) { + add(Name, new TypeScheme { {}, Ty }, Kind); +} + +using TVSub = std::unordered_map; + +Type* substituteType(Type* Ty, const TVSub& Sub) { + switch (Ty->getKind()) { + case TypeKind::App: + { + auto A = static_cast(Ty); + auto NewLeft = substituteType(A->getLeft(), Sub); + auto NewRight = substituteType(A->getRight(), Sub); + if (A->getLeft() == NewLeft && A->getRight() == NewRight) { + return Ty; + } + return new TApp(NewLeft, NewRight); + } + case TypeKind::Con: + return Ty; + case TypeKind::Var: + { + auto NewTy = Ty->find(); + if (NewTy->getKind() != TypeKind::Var) { + return substituteType(NewTy, Sub); + } + auto Match = Sub.find(static_cast(NewTy)); + return Match == Sub.end() + ? NewTy + : Match->second; + } + case TypeKind::Fun: + { + auto F = static_cast(Ty); + auto NewLeft = substituteType(F->getLeft(), Sub); + auto NewRight = substituteType(F->getRight(), Sub); + if (F->getLeft() == NewLeft && F->getRight() == NewRight) { + return Ty; + } + return new TFun(NewLeft, NewRight); + } } - auto F = static_cast(Scm); - ZEN_ASSERT(F->TVs == nullptr || F->TVs->empty()); - return F->Type; } -void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) { - getContext().Env.add(Name, Scm, Kind); + + +Type* Checker::instantiate(TypeScheme* Scm) { + TVSub Sub; + for (auto TV: Scm->Unbound) { + auto Fresh = createTVar(); + Sub[TV] = Fresh; + } + return substituteType(Scm->getType(), Sub); } -Type* Checker::getReturnType() { - auto Ty = getContext().ReturnType; - ZEN_ASSERT(Ty != nullptr); - return Ty; -} +std::tuple Checker::inferExpr(TypeEnv& Env, Expression* Expr, Type* RetTy) { -static bool hasTypeVar(TVSet& Set, Type* Type) { - for (auto TV: Type->getTypeVars()) { - if (Set.count(TV)) { - return true; + ConstraintSet Out; + Type* Ty; + + for (auto Ann: Expr->Annotations) { + if (Ann->getKind() == NodeKind::TypeAssertAnnotation) { + auto [AnnOut, AnnTy] = inferTypeExpr(Env, static_cast(Ann)->getTypeExpression()); + mergeTo(Out, AnnOut); } } - return false; -} -void Checker::setContext(InferContext* Ctx) { - ActiveContext = Ctx; -} + switch (Expr->getKind()) { -void Checker::popContext() { - ZEN_ASSERT(ActiveContext); - ActiveContext = ActiveContext->Parent; -} - -InferContext& Checker::getContext() { - ZEN_ASSERT(ActiveContext); - return *ActiveContext; -} - -void Checker::makeEqual(Type* A, Type* B, Node* Source) { - addConstraint(new CEqual(A, B, Source)); -} - -void Checker::addConstraint(Constraint* C) { - - switch (C->getKind()) { - - case ConstraintKind::Field: - // FIXME Check if this is all that needs to be done - getContext().Constraints->push_back(C); - break; - - case ConstraintKind::Equal: - { - auto Y = static_cast(C); - - // This will store all inference contexts in Contexts, from most local - // one to most general one. Because this order is not ideal, the code - // below will have to handle that. - auto Curr = &getContext(); - std::vector Contexts; - for (;;) { - Contexts.push_back(Curr); - Curr = Curr->Parent; - if (!Curr) { - break; + case NodeKind::ReferenceExpression: + { + auto E = static_cast(Expr); + auto Name = E->Name.getCanonicalText(); + auto Match = Env.lookup(Name, SymbolKind::Var); + if (Match == nullptr) { + DE.add(Name, E->Name); + Ty = createTVar(); + } else { + Ty = instantiate(Match); } + break; } - std::size_t Global = Contexts.size()-1; - - // If no MaxLevelLeft was found, that means that not a single - // corresponding type variable was found in the contexts. We set it to - // Contexts.size()-1, which corresponds to the global inference context. - std::size_t MaxLevelLeft = Global; - for (std::size_t I = 0; I < Global; I++) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Left)) { - MaxLevelLeft = I; - break; - } - } - - // Same as above but now mirrored for Y->Right - std::size_t MaxLevelRight = Global; - for (std::size_t I = 0; I < Global; I++) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Right)) { - MaxLevelRight = I; - break; - } - } - - // The lowest index is determined by the one that has no type variables - // in Y->Left AND in Y->Right. This implies max() must be used, so that - // the very first enounter of a type variable matters. - auto UpperLevel = std::max(MaxLevelLeft, MaxLevelRight); - - // Now find the lowest index LowerLevel such that all the contexts that are more - // local do not contain any type variables that are present in the - // equality constraint. - std::size_t LowerLevel = UpperLevel; - for (std::size_t I = Global; I-- > 0; ) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { - LowerLevel = I; - break; - } - } - - if (UpperLevel == LowerLevel || MaxLevelLeft == Global || MaxLevelRight == Global) { - unify(Y->Left, Y->Right, Y->Source); - } else { - Contexts[UpperLevel]->Constraints->push_back(C); - } - - break; - } - - case ConstraintKind::Many: - { - auto Y = static_cast(C); - for (auto Element: Y->Elements) { - addConstraint(Element); - } - break; - } - - case ConstraintKind::Empty: - break; - - } - -} - -void Checker::forwardDeclare(Node* X) { - - switch (X->getKind()) { - - case NodeKind::ExpressionStatement: - case NodeKind::ReturnStatement: - case NodeKind::IfStatement: - break; - - case NodeKind::SourceFile: - { - auto File = static_cast(X); - for (auto Element: File->Elements) { - forwardDeclare(Element) ; - } - break; - } - - case NodeKind::ClassDeclaration: - { - auto Class = static_cast(X); - // for (auto TE: Class->TypeVars) { - // auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); - // // TV->Contexts.emplace(Class->Name->getCanonicalText()); - // TE->setType(TV); - // } - for (auto Element: Class->Elements) { - forwardDeclare(Element); - } - break; - } - - case NodeKind::InstanceDeclaration: - { - auto Decl = static_cast(X); - - // Needed to set the associated Type on the CST node - for (auto TE: Decl->TypeExps) { - inferTypeExpression(TE); - } - - auto Match = InstanceMap.find(Decl->Name->getCanonicalText()); - if (Match == InstanceMap.end()) { - InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl }); - } else { - Match->second.push_back(Decl); - } - - for (auto Element: Decl->Elements) { - forwardDeclare(Element); - } - - break; - } - - case NodeKind::PrefixFunctionDeclaration: - case NodeKind::InfixFunctionDeclaration: - case NodeKind::SuffixFunctionDeclaration: - case NodeKind::NamedFunctionDeclaration: - break; - - case NodeKind::VariableDeclaration: - { - auto Decl = static_cast(X); - Type* Ty; - if (Decl->TypeAssert) { - Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); - } else { - Ty = createTypeVar(); - } - Decl->setType(Ty); - break; - } - - case NodeKind::VariantDeclaration: - { - auto Decl = static_cast(X); - - setContext(Decl->Ctx); - - std::vector Vars; - for (auto TE: Decl->TVs) { - auto TV = createRigidVar(TE->Name->getCanonicalText()); - Decl->Ctx->TVs->emplace(TV); - Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type); - Vars.push_back(TV); - } - - Type* Ty = createConType(Decl->Name->getCanonicalText()); - - // Build the type that is actually returned by constructor functions - auto RetTy = Ty; - for (auto Var: Vars) { - RetTy = new Type(TApp(RetTy, Var)); - } - - // Must be added early so we can create recursive types - Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); - - for (auto Member: Decl->Members) { - switch (Member->getKind()) { - case NodeKind::TupleVariantDeclarationMember: - { - auto TupleMember = static_cast(Member); - std::vector ParamTypes; - for (auto Element: TupleMember->Elements) { - // inferTypeExpression will look up any TVars that were part of the signature of Decl - ParamTypes.push_back(inferTypeExpression(Element, false)); - } - Decl->Ctx->Parent->Env.add( - TupleMember->Name->getCanonicalText(), - new Forall( - Decl->Ctx->TVs, - Decl->Ctx->Constraints, - Type::buildArrow(ParamTypes, RetTy) - ), - SymKind::Var - ); + case NodeKind::LiteralExpression: + { + auto E = static_cast(Expr); + switch (E->Token ->getKind()) { + case NodeKind::IntegerLiteral: + Ty = getIntType(); break; - } - case NodeKind::RecordVariantDeclarationMember: - { - // TODO + case NodeKind::StringLiteral: + Ty = getStringType(); break; - } default: ZEN_UNREACHABLE } + break; } - popContext(); - - break; - } - - case NodeKind::RecordDeclaration: - { - auto Decl = static_cast(X); - - setContext(Decl->Ctx); - - std::vector Vars; - for (auto TE: Decl->Vars) { - auto TV = createRigidVar(TE->Name->getCanonicalText()); - Decl->Ctx->TVs->emplace(TV); - Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type); - Vars.push_back(TV); + case NodeKind::CallExpression: + { + auto E = static_cast(Expr); + auto RetTy = createTVar(); + Type* FunTy = RetTy; + for (auto It = E->Args.end(); It-- != E->Args.begin();) { + auto [ArgOut, ArgTy] = inferExpr(Env, *It, RetTy); + mergeTo(Out, ArgOut); + FunTy = new TFun(ArgTy, FunTy); + } + auto FunOut = checkExpr(Env, E->Function, FunTy, RetTy); + mergeTo(Out, FunOut); + Ty = RetTy; + break; } - auto Name = Decl->Name->getCanonicalText(); - auto Ty = createConType(Name); - - // Must be added early so we can create recursive types - Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type); - - Type* RetTy = Ty; - for (auto TV: Vars) { - RetTy = new Type(TApp(RetTy, TV)); + case NodeKind::InfixExpression: + { + auto E = static_cast(Expr); + auto [LeftOut, LeftTy] = inferExpr(Env, E->Left, RetTy); + mergeTo(Out, LeftOut); + auto [RightOut, RightTy] = inferExpr(Env, E->Right, RetTy); + mergeTo(Out, RightOut); + auto Name = E->Operator.getCanonicalText(); + auto Match = Env.lookup(Name, SymbolKind::Var); + if (Match == nullptr) { + DE.add(Name, E->Operator); + return { Out, createTVar() }; + } + auto RetTy = createTVar(); + auto FunTy = new TFun(LeftTy, new TFun(RightTy, RetTy)); + Out.push_back(new CTypesEqual(FunTy, instantiate(Match), E)); + Ty = RetTy; + break; } - // Corresponds to the logic of one branch of a VariantDeclarationMember - Type* FieldsTy = new Type(TNil()); - for (auto Field: Decl->Fields) { - FieldsTy = new Type( - TField( - Field->Name->getCanonicalText(), - new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), - FieldsTy - ) - ); - } - Decl->Ctx->Parent->Env.add( - Name, - new Forall( - Decl->Ctx->TVs, - Decl->Ctx->Constraints, - new Type(TArrow(FieldsTy, RetTy)) - ), - SymKind::Var - ); - - popContext(); - - break; - } + // TODO LambdaExpression default: ZEN_UNREACHABLE } + Expr->setType(Ty); + + return { Out, Ty }; } -void Checker::initialize(Node* N) { - - struct Init : public CSTVisitor { - - Checker& C; - - std::stack Contexts; - - InferContext* createDerivedContext() { - return C.createInferContext(Contexts.top()); - } - - void visitVariantDeclaration(VariantDeclaration* Decl) { - Decl->Ctx = createDerivedContext(); - } - - void visitRecordDeclaration(RecordDeclaration* Decl) { - Decl->Ctx = createDerivedContext(); - } - - void visitMatchCase(MatchCase* C) { - C->Ctx = createDerivedContext(); - Contexts.push(C->Ctx); - visitEachChild(C); - Contexts.pop(); - } - - void visitSourceFile(SourceFile* SF) { - SF->Ctx = C.createInferContext(); - Contexts.push(SF->Ctx); - visitEachChild(SF); - Contexts.pop(); - } - - void visitFunctionDeclaration(FunctionDeclaration* Func) { - Func->Ctx = createDerivedContext(); - Contexts.push(Func->Ctx); - visitEachChild(Func); - Contexts.pop(); - } - - // void visitVariableDeclaration(VariableDeclaration* Var) { - // Var->Ctx = Contexts.top(); - // visitEachChild(Var); - // } - - }; - - Init I { {}, *this }; - I.visit(N); - -} - -void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { - - // std::cerr << "declare " << Let->getNameAsString() << std::endl; - - setContext(Let->Ctx); - - auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { - auto Id = Class->Name->getCanonicalText(); - auto Ctx = &getContext(); - std::vector Out; - for (auto TE: Class->TypeVars) { - auto Name = TE->Name->getCanonicalText(); - auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); - TV->asVar().Context.emplace(Id); - Ctx->Env.add(Name, new Forall(TV), SymKind::Type); - Out.push_back(TV); - } - return Out; - }; - - // If declaring a let-declaration inside a type class declaration, - // we need to mark that the let-declaration requires this class. - // This marking is set on the rigid type variables of the class, which - // are then added to this local type environment. - if (Let->isClass()) { - addClassVars(static_cast(Let->Parent), true); +void Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& Out) { + switch (P->getKind()) { + case NodeKind::BindPattern: + { + auto Q = static_cast(P); + // TODO Make a TypedNode out of a Pattern? + Out.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var); + break; + } + default: + ZEN_UNREACHABLE } +} - // Here we infer the primary type of the let declaration. If there's a - // type assert, that assert should be authoritative so we use that. - // Otherwise, the type is not further specified and we create a new - // unification variable. +std::tuple Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) { + + ConstraintSet Out; Type* Ty; - if (Let->hasTypeAssert()) { - Ty = inferTypeExpression(Let->getTypeAssert()->TypeExpression); - } else { - Ty = createTypeVar(); - } - Let->setType(Ty); - // If declaring a let-declaration inside a type instance declaration, - // we need to perform some work to make sure the type asserts of the - // corresponding let-declaration in the type class declaration are - // accounted for. - if (Let->isInstance()) { + switch (TE->getKind()) { - auto Instance = static_cast(Let->Parent); - auto Class = cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); - - if (Class == nullptr) { - // TODO print diagnostic - // DE.add(Instance->Name->getCanonicalText()); - goto after_isinstance; - } - - auto Decl = Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var); - - if (Decl == nullptr) { - - // TODO print diagnostic - // DE.add(Let->getNameAsStrings), Let->getName()); - goto after_isinstance; - - } - - if (!isa(Decl)) { - - // TODO print diagnostic - // DE.add(Decl); - goto after_isinstance; - - } - - auto FuncDecl = cast(Decl); - - auto Params = addClassVars(Class, false); - - // The type asserts in the type class declaration might make use of - // the type parameters of the type class declaration, so it is - // important to make them available in the type environment. Moreover, - // we will be unifying them with the actual types declared in the - // instance declaration, so we keep track of them. - // std::vector Params; - // TVSub Sub; - // for (auto TE: Class->TypeVars) { - // auto TV = createTypeVar(); - // Sub.emplace(cast(TE->getType()), TV); - // Params.push_back(TV); - // } - - // Here we do the actual unification of e.g. Eq a with Eq Bool. The - // unification variables we created previously will be unified with - // e.g. Bool, which causes the type assert to also collapse to e.g. - // Bool -> Bool -> Bool. - for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { - makeEqual(Param, TE->getType(), TE); - } - - // It would be very strange if there was no type assert in the type - // class let-declaration but we rather not let the compiler crash if that happens. - if (FuncDecl->hasTypeAssert()) { - // Note that we can't do SigLet->TypeAssert->TypeExpression->getType() - // because we need to re-generate the type within the local context of - // this let-declaration. - // TODO make CEqual accept multiple nodes - makeEqual(Ty, inferTypeExpression(FuncDecl->getTypeAssert()->TypeExpression), Let); - } - - } - -after_isinstance: - - if (Let->hasBody()) { - switch (Let->getBody()->getKind()) { - case NodeKind::LetExprBody: - break; - case NodeKind::LetBlockBody: + case NodeKind::ReferenceTypeExpression: { - auto Block = static_cast(Let->getBody()); - Let->Ctx->ReturnType = createTypeVar(); - for (auto Element: Block->Elements) { - forwardDeclare(Element); + auto E = static_cast(TE); + auto Name = E->Name->getCanonicalText(); + auto Match = Env.lookup(Name, SymbolKind::Type); + if (Match == nullptr) { + DE.add(Name, E->Name); + Ty = createTVar(); + } else { + Ty = instantiate(Match); } break; } - default: - ZEN_UNREACHABLE - } + + default: + ZEN_UNREACHABLE + } - if (!Let->isInstance()) { - Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var); - } + TE->setType(Ty); + return { Out, Ty }; } -void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) { +ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaration* D) { - // std::cerr << "infer " << Decl->getNameAsString() << std::endl; + auto TA = D->getTypeAssert(); + auto Params = D->getParams(); + auto Body = D->getBody(); - auto OldCtx = ActiveContext; - setContext(Decl->Ctx); + ConstraintSet Out; - std::vector ParamTypes; - Type* RetType; + TypeEnv NewEnv { Env }; - for (auto Param: Decl->getParams()) { - ParamTypes.push_back(inferPattern(Param->Pattern)); + auto RetTy = createTVar(); + + Type* Ty = RetTy; + for (auto It = Params.end(); It-- != Params.begin(); ) { + auto Param = *It; + auto ParamTy = createTVar(); + visitPattern(Param->Pattern, ParamTy, NewEnv); + Ty = new TFun(ParamTy, Ty); } - if (Decl->hasBody()) { - switch (Decl->getBody()->getKind()) { - case NodeKind::LetExprBody: + if (TA != nullptr) { + auto [TEOut, TETy] = inferTypeExpr(Env, TA->TypeExpression); + mergeTo(Out, TEOut); + Out.push_back(new CTypesEqual(Ty, TETy, TA->TypeExpression)); + } + + if (Body != nullptr) { + // TODO elminate BlockBody and replace with BlockExpr + ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody); + auto [BodyOut, BodyTy] = inferExpr(NewEnv, static_cast(Body)->Expression, RetTy); + mergeTo(Out, BodyOut); + Out.push_back(new CTypesEqual(RetTy, BodyTy, Body)); + } + + // Env.add(D->getNameAsString(), Ty, SymbolKind::Var); + + D->setType(Ty); + + return Out; +} + +ConstraintSet Checker::inferVariableDeclaration(TypeEnv& Env, VariableDeclaration* Decl) { + ConstraintSet Out; + // TODO + return Out; +} + +bool hasTypeVar(Type* Ty, TVar* TV) { + switch (TV->getKind()) { + case TypeKind::App: { - auto Expr = static_cast(Decl->getBody()); - RetType = inferExpression(Expr->Expression); - break; + auto T = static_cast(Ty); + return hasTypeVar(T->getLeft(), TV) + || hasTypeVar(T->getRight(), TV); } - case NodeKind::LetBlockBody: + case TypeKind::Con: + return false; + case TypeKind::Fun: { - auto Block = static_cast(Decl->getBody()); - RetType = Decl->Ctx->ReturnType; - for (auto Element: Block->Elements) { - infer(Element); + auto T = static_cast(Ty); + return hasTypeVar(T->getLeft(), TV) + || hasTypeVar(T->getRight(), TV); + } + case TypeKind::Var: + { + auto T = static_cast(Ty); + return T->find() == TV; + } + } +} + +bool TypeEnv::hasVar(TVar* TV) const { + for (auto [_, Scm]: Mapping) { + if (Scm->Unbound.count(TV)) { + // FIXME + ZEN_UNREACHABLE + } + if (hasTypeVar(Scm->getType(), TV)) { + return true; + } + } + return false; +} + +auto getUnbound(const TypeEnv& Env, Type* Ty) { + struct Visitor : public TypeVisitor { + const TypeEnv& Env; + Visitor(const TypeEnv& Env): + Env(Env) {} + std::vector Out; + void visitVar(TVar* TV) { + if (!Env.hasVar(TV)) { + Out.push_back(TV); + } + } + } V { Env }; + V.visit(Ty); + return V.Out; +} + +ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector& Elements, Type* RetTy) { + + using Graph = zen::hash_graph; + + TypeEnv NewEnv { Env }; + + Graph G; + + std::function populate = [&](auto From, auto N) { + struct Visitor : CSTVisitor { + Graph& G; + Node* From; + void visitReferenceExpression(ReferenceExpression* E) { + auto To = E->getScope()->lookup(E->getSymbolPath()); + if (isa(To)) { + To = To->Parent; + } + if (To != nullptr) { + G.add_edge(From, To); } - break; } - default: - ZEN_UNREACHABLE + } V { {}, G, From }; + V.visit(N); + }; + + std::vector Stmts; + + for (auto Element: Elements) { + if (isa(Element)) { + auto M = static_cast(Element); + G.add_vertex(Element); + if (M->hasBody()) { + populate(M, M->getBody()); + } + } else if (isa(Element)) { + auto M = static_cast(Element); + G.add_vertex(Element); + if (M->hasExpression()) { + populate(M, M->getExpression()); + } + } else { + Stmts.push_back(cast(Element)); } - } else { - RetType = createTypeVar(); } - makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl); + for (auto Nodes: zen::toposort(G)) { + ConstraintSet Out; + for (auto N: Nodes) { + if (isa(N)) { + mergeTo(Out, inferFunctionDeclaration(Env, static_cast(N))); + } else if (isa(N)) { + mergeTo(Out, inferVariableDeclaration(Env, static_cast(N))); + } else { + ZEN_UNREACHABLE + } + } + solve(Out); + for (auto N: Nodes) { + if (isa(N)) { + auto M = static_cast(N); + auto Unbound = getUnbound(Env, cast(N)->getType()); + Env.add( + M->getNameAsString(), + new TypeScheme { { Unbound.begin(), Unbound.end() }, M->getType() }, + SymbolKind::Var + ); + } + } + } - setContext(OldCtx); + ConstraintSet Out; + + for (auto Stmt: Stmts) { + mergeTo(Out, inferElement(Env, Stmt, RetTy)); + } + + return Out; } -void Checker::infer(Node* N) { +ConstraintSet Checker::inferElement(TypeEnv& Env, Node* N, Type* RetTy) { switch (N->getKind()) { - case NodeKind::SourceFile: - { - auto File = static_cast(N); - for (auto Element: File->Elements) { - infer(Element); - } - break; - } - - case NodeKind::ClassDeclaration: - { - auto Decl = static_cast(N); - for (auto Element: Decl->Elements) { - infer(Element); - } - break; - } - - case NodeKind::InstanceDeclaration: - { - auto Decl = static_cast(N); - for (auto Element: Decl->Elements) { - infer(Element); - } - break; - } - - case NodeKind::VariantDeclaration: - case NodeKind::RecordDeclaration: - // Nothing to do for a type-level declaration - break; - - case NodeKind::IfStatement: - { - auto IfStmt = static_cast(N); - for (auto Part: IfStmt->Parts) { - if (Part->Test != nullptr) { - makeEqual(BoolType, inferExpression(Part->Test), Part->Test); - } - for (auto Element: Part->Elements) { - infer(Element); - } - } - break; - } - - case NodeKind::ReturnStatement: - { - auto RetStmt = static_cast(N); - Type* ReturnType; - if (RetStmt->Expression) { - makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression); - } else { - ReturnType = UnitType; - makeEqual(UnitType, getReturnType(), N); - } - break; - } - case NodeKind::PrefixFunctionDeclaration: case NodeKind::InfixFunctionDeclaration: case NodeKind::SuffixFunctionDeclaration: case NodeKind::NamedFunctionDeclaration: - { - auto Decl = static_cast(N); - if (Decl->Visited) { - break; - } - Decl->IsCycleActive = true; - Decl->Visited = true; - inferFunctionDeclaration(Decl); - Decl->IsCycleActive = false; - break; - } - - case NodeKind::VariableDeclaration: - { - auto Decl = static_cast(N); - auto Ty = Decl->getType(); - if (Decl->Body) { - ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); - auto E = static_cast(Decl->Body); - auto Ty2 = inferExpression(E->Expression); - makeEqual(Ty, Ty2, Decl); - } - auto Ty3 = inferPattern(Decl->Pattern); - makeEqual(Ty, Ty3, Decl); - break; - } + return inferFunctionDeclaration(Env, static_cast(N)); case NodeKind::ExpressionStatement: - { - auto ExprStmt = static_cast(N); - inferExpression(ExprStmt->Expression); - break; - } + { + auto M = static_cast(N); + auto [Out, _] = inferExpr(Env, M->Expression, RetTy); + return Out; + } + + case NodeKind::ReturnStatement: + { + auto M = static_cast(N); + if (!M->hasExpression()) { + return {}; + } + auto [ValOut, ValTy] = inferExpr(Env, M->getExpression(), RetTy); + return { new CTypesEqual(ValTy, RetTy, N) }; + } default: ZEN_UNREACHABLE @@ -743,1243 +439,142 @@ void Checker::infer(Node* N) { } -Type* Checker::createConType(ByteString Name) { - return new Type(TCon(NextConTypeId++, Name)); +ConstraintSet Checker::inferSourceFile(TypeEnv& Env, SourceFile* SF) { + return inferMany(Env, SF->Elements, nullptr); } -Type* Checker::createRigidVar(ByteString Name) { - auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}})); - getContext().TVs->emplace(TV); - return TV; -} +ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, Type* RetTy) { -Type* Checker::createTypeVar() { - auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {})); - getContext().TVs->emplace(TV); - return TV; -} - -InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) { - auto Ctx = new InferContext; - Ctx->Parent = Parent; - Ctx->TVs = new TVSet; - Ctx->Constraints = new ConstraintSet; - return Ctx; -} - -Type* Checker::instantiate(Scheme* Scm, Node* Source) { - - switch (Scm->getKind()) { - - case SchemeKind::Forall: - { - auto F = static_cast(Scm); - - TVSub Sub; - for (auto TV: *F->TVs) { - auto Fresh = createTypeVar(); - // std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl; - Fresh->asVar().Context = TV->asVar().Context; - Sub[TV] = Fresh; - } - - for (auto Constraint: *F->Constraints) { - - // FIXME improve this - if (Constraint->getKind() == ConstraintKind::Equal) { - auto Eq = static_cast(Constraint); - Eq->Left = solveType(Eq->Left); - Eq->Right = solveType(Eq->Right); - } - - auto NewConstraint = Constraint->substitute(Sub); - - // This makes error messages prettier by relating the typing failure - // to the call site rather than the definition. - if (NewConstraint->getKind() == ConstraintKind::Equal) { - auto Eq = static_cast(Constraint); - Eq->Source = Source; - } - - addConstraint(NewConstraint); - } - - // This call to solve happens because constraints may have already - // been solved, with some unification variables being erased. To make - // sure we instantiate unification variables that are still in use - // we solve before substituting. - return solveType(F->Type)->substitute(Sub); - } - - } - - ZEN_UNREACHABLE -} - -void Checker::inferConstraintExpression(ConstraintExpression* C) { - switch (C->getKind()) { - case NodeKind::TypeclassConstraintExpression: - { - auto D = static_cast(C); - std::vector Types; - for (auto TE: D->TEs) { - auto Ty = inferTypeExpression(TE); - Ty->asVar().Provided->emplace(D->Name->getCanonicalText()); - Types.push_back(Ty); - } - break; - } - case NodeKind::EqualityConstraintExpression: - { - auto D = static_cast(C); - makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); - break; - } - default: - ZEN_UNREACHABLE - } -} - -Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { - - switch (N->getKind()) { - - case NodeKind::ReferenceTypeExpression: - { - auto RefTE = static_cast(N); - auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type); - Type* Ty; - if (Scm == nullptr) { - DE.add(RefTE->Name->getCanonicalText(), RefTE->Name); - Ty = createTypeVar(); - } else { - Ty = instantiate(Scm, RefTE); - } - N->setType(Ty); - return Ty; - } - - case NodeKind::AppTypeExpression: - { - auto AppTE = static_cast(N); - Type* Ty = inferTypeExpression(AppTE->Op, AutoVars); - for (auto Arg: AppTE->Args) { - Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars))); - } - N->setType(Ty); - return Ty; - } - - case NodeKind::VarTypeExpression: - { - auto VarTE = static_cast(N); - auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type); - if (Ty == nullptr) { - if (!AutoVars || Config.typeVarsRequireForall()) { - DE.add(VarTE->Name->getCanonicalText(), VarTE->Name); - } - Ty = createRigidVar(VarTE->Name->getCanonicalText()); - addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); - } - ZEN_ASSERT(Ty->isVar()); - N->setType(Ty); - return Ty; - } - - case NodeKind::RecordTypeExpression: - { - auto RecTE = static_cast(N); - auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); - for (auto [Field, Comma]: RecTE->Fields) { - Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); - } - N->setType(Ty); - return Ty; - } - - case NodeKind::TupleTypeExpression: - { - auto TupleTE = static_cast(N); - std::vector ElementTypes; - for (auto [TE, Comma]: TupleTE->Elements) { - ElementTypes.push_back(inferTypeExpression(TE, AutoVars)); - } - auto Ty = new Type(TTuple(ElementTypes)); - N->setType(Ty); - return Ty; - } - - case NodeKind::NestedTypeExpression: - { - auto NestedTE = static_cast(N); - auto Ty = inferTypeExpression(NestedTE->TE, AutoVars); - N->setType(Ty); - return Ty; - } - - case NodeKind::ArrowTypeExpression: - { - auto ArrowTE = static_cast(N); - std::vector ParamTypes; - for (auto ParamType: ArrowTE->ParamTypes) { - ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars)); - } - auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars); - auto Ty = Type::buildArrow(ParamTypes, ReturnType); - N->setType(Ty); - return Ty; - } - - case NodeKind::QualifiedTypeExpression: - { - auto QTE = static_cast(N); - for (auto [C, Comma]: QTE->Constraints) { - inferConstraintExpression(C); - } - auto Ty = inferTypeExpression(QTE->TE, AutoVars); - N->setType(Ty); - return Ty; - } - - default: - ZEN_UNREACHABLE - - } -} - -Type* sortRow(Type* Ty) { - std::map Fields; - while (Ty->isField()) { - auto& Field = Ty->asField(); - Fields.emplace(Field.Name, Ty); - Ty = Field.RestTy; - } - for (auto [Name, Field]: Fields) { - Ty = new Type(TField(Name, Field->asField().Ty, Ty)); - } - return Ty; -} - -Type* Checker::inferExpression(Expression* X) { - - Type* Ty; - - for (auto A: X->Annotations) { - if (A->getKind() == NodeKind::TypeAssertAnnotation) { - inferTypeExpression(static_cast(A)->TE); - } - } - - switch (X->getKind()) { - - case NodeKind::MatchExpression: - { - auto Match = static_cast(X); - Type* ValTy; - if (Match->Value) { - ValTy = inferExpression(Match->Value); - } else { - ValTy = createTypeVar(); - } - Ty = createTypeVar(); - for (auto Case: Match->Cases) { - auto OldCtx = &getContext(); - setContext(Case->Ctx); - auto PattTy = inferPattern(Case->Pattern); - makeEqual(PattTy, ValTy, Case); - auto ExprTy = inferExpression(Case->Expression); - makeEqual(ExprTy, Ty, Case->Expression); - setContext(OldCtx); - } - if (!Match->Value) { - Ty = new Type(TArrow(ValTy, Ty)); - } - break; - } - - case NodeKind::RecordExpression: - { - auto Record = static_cast(X); - Ty = new Type(TNil()); - for (auto [Field, Comma]: Record->Fields) { - Ty = new Type(TField( - Field->Name->getCanonicalText(), - new Type(TPresent(inferExpression(Field->getExpression()))), - Ty - )); - } - Ty = sortRow(Ty); - break; - } + switch (Expr->getKind()) { case NodeKind::LiteralExpression: - { - auto Const = static_cast(X); - Ty = inferLiteral(Const->Token); - break; - } - - case NodeKind::ReferenceExpression: - { - auto Ref = static_cast(X); - auto Name = Ref->Name.getCanonicalText(); - ZEN_ASSERT(Ref->ModulePath.empty()); - if (Ref->Name.isIdentifierAlt()) { - auto Scm = lookup(Name, SymKind::Var); - if (!Scm) { - DE.add(Name, Ref->Name); - Ty = createTypeVar(); - break; + { + auto E = static_cast(Expr); + switch (E->Token->getKind()) { + case NodeKind::IntegerLiteral: + if (*Expected == *getIntType()) { + return {}; + } + break; + case NodeKind::StringLiteral: + if (*Expected == *getStringType()) { + return {}; + } + break; + default: + break; } - Ty = instantiate(Scm, X); - break; - } - auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); - if (!Target) { - DE.add(Name, Ref->Name); - Ty = createTypeVar(); - break; - } - if (isa(Target)) { - auto Let = static_cast(Target); - if (Let->IsCycleActive) { - Ty = Let->getType(); - break; - } - if (!Let->Visited) { - infer(Let); - } - } - auto Scm = lookup(Name, SymKind::Var); - ZEN_ASSERT(Scm); - Ty = instantiate(Scm, X); - break; - } + } - case NodeKind::CallExpression: - { - auto Call = static_cast(X); - auto OpTy = inferExpression(Call->Function); - Ty = createTypeVar(); - std::vector ArgTypes; - for (auto Arg: Call->Args) { - ArgTypes.push_back(inferExpression(Arg)); + // TODO + // case NodeKind::FunctionExpression: + + default: + { + auto [Out, Actual] = inferExpr(Env, Expr, RetTy); + Out.push_back(new CTypesEqual(Actual, Expected, Expr)); + return Out; } - makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X); - break; - } - case NodeKind::InfixExpression: - { - auto Infix = static_cast(X); - auto Scm = lookup(Infix->Operator.getCanonicalText(), SymKind::Var); - if (Scm == nullptr) { - DE.add(Infix->Operator.getCanonicalText(), Infix->Operator); - Ty = createTypeVar(); - break; - } - auto OpTy = instantiate(Scm, Infix->Operator); - Ty = createTypeVar(); - std::vector ArgTys; - ArgTys.push_back(inferExpression(Infix->Left)); - ArgTys.push_back(inferExpression(Infix->Right)); - makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X); - break; - } + } - case NodeKind::TupleExpression: - { - auto Tuple = static_cast(X); - std::vector Types; - for (auto [E, Comma]: Tuple->Elements) { - Types.push_back(inferExpression(E)); - } - Ty = new Type(TTuple(Types)); - break; - } +} - case NodeKind::MemberExpression: - { - auto Member = static_cast(X); - auto ExprTy = inferExpression(Member->E); - switch (Member->Name->getKind()) { - case NodeKind::IntegerLiteral: +void Checker::solve(const std::vector& Constraints) { + for (auto C: Constraints) { + switch (C->getKind()) { + case ConstraintKind::TypesEqual: { - auto I = static_cast(Member->Name); - Ty = createTypeVar(); - addConstraint(new CField(ExprTy, I->asInt(), Ty, Member)); + auto D = static_cast(C); + unifyTypeType(D->getLeft(), D->getRight(), D->getOrigin()); break; } - case NodeKind::Identifier: - { - auto K = static_cast(Member->Name); - Ty = createTypeVar(); - auto RestTy = createTypeVar(); - makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member); - break; - } - default: - ZEN_UNREACHABLE - } - break; } - - case NodeKind::NestedExpression: - { - auto Nested = static_cast(X); - Ty = inferExpression(Nested->Inner); - break; - } - - default: - ZEN_UNREACHABLE - } - - // Ty = find(Ty); - X->setType(Ty); - return Ty; } -RecordPatternField* getRestField(std::vector> Fields) { - for (auto [Field, Comma]: Fields) { - if (Field->DotDot) { - return Field; - } - } - return nullptr; -} - -Type* Checker::inferPattern( - Pattern* Pattern, - ConstraintSet* Constraints, - TVSet* TVs -) { - - switch (Pattern->getKind()) { - - case NodeKind::BindPattern: - { - auto P = static_cast(Pattern); - auto Ty = createTypeVar(); - addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var); - return Ty; - } - - case NodeKind::NamedTuplePattern: - { - auto P = static_cast(Pattern); - auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); - std::vector ElementTypes; - for (auto P2: P->Patterns) { - ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); - } - if (!Scm) { - DE.add(P->Name->getCanonicalText(), P->Name); - return createTypeVar(); - } - auto Ty = instantiate(Scm, P); - auto RetTy = createTypeVar(); - makeEqual(Ty, Type::buildArrow(ElementTypes, RetTy), P); - return RetTy; - } - - case NodeKind::RecordPattern: - { - auto P = static_cast(Pattern); - auto RestField = getRestField(P->Fields); - Type* RecordTy; - if (RestField == nullptr) { - RecordTy = new Type(TNil()); - } else if (RestField->Pattern) { - RecordTy = inferPattern(RestField->Pattern); - } else { - RecordTy = createTypeVar(); - } - for (auto [Field, Comma]: P->Fields) { - if (Field->DotDot) { - continue; - } - Type* FieldTy; - if (Field->Pattern) { - FieldTy = inferPattern(Field->Pattern, Constraints, TVs); - } else { - FieldTy = createTypeVar(); - addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var); - } - RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy)); - } - return RecordTy; - } - - case NodeKind::NamedRecordPattern: - { - auto P = static_cast(Pattern); - auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); - if (Scm == nullptr) { - DE.add(P->Name->getCanonicalText(), P->Name); - return createTypeVar(); - } - auto RestField = getRestField(P->Fields); - Type* RecordTy; - if (RestField == nullptr) { - RecordTy = new Type(TNil()); - } else if (RestField->Pattern) { - RecordTy = inferPattern(RestField->Pattern); - } else { - RecordTy = createTypeVar(); - } - for (auto [Field, Comma]: P->Fields) { - if (Field->DotDot) { - continue; - } - Type* FieldTy; - if (Field->Pattern) { - FieldTy = inferPattern(Field->Pattern, Constraints, TVs); - } else { - FieldTy = createTypeVar(); - addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var); - } - RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy)); - } - auto Ty = instantiate(Scm, P); - auto RetTy = createTypeVar(); - makeEqual(Ty, new Type(TArrow(RecordTy, RetTy)), P); - return RetTy; - } - - case NodeKind::TuplePattern: - { - auto P = static_cast(Pattern); - std::vector ElementTypes; - for (auto [Element, Comma]: P->Elements) { - ElementTypes.push_back(inferPattern(Element)); - } - return new Type(TTuple(ElementTypes)); - } - - case NodeKind::ListPattern: - { - auto P = static_cast(Pattern); - auto ElementType = createTypeVar(); - for (auto [Element, Separator]: P->Elements) { - makeEqual(ElementType, inferPattern(Element), P); - } - return new Type(TApp(ListType, ElementType)); - } - - case NodeKind::NestedPattern: - { - auto P = static_cast(Pattern); - return inferPattern(P->P, Constraints, TVs); - } - - case NodeKind::LiteralPattern: - { - auto P = static_cast(Pattern); - return inferLiteral(P->Literal); - } - - default: - ZEN_UNREACHABLE - - } - -} - -Type* Checker::inferLiteral(Literal* L) { - Type* Ty; - switch (L->getKind()) { - case NodeKind::IntegerLiteral: - Ty = lookupMono("Int", SymKind::Type); - break; - case NodeKind::StringLiteral: - Ty = lookupMono("String", SymKind::Type); - break; - default: - ZEN_UNREACHABLE - } - ZEN_ASSERT(Ty != nullptr); - return Ty; -} - -void Checker::populate(SourceFile* SF) { - - struct Visitor : public CSTVisitor { - - Graph& RefGraph; - - std::stack Stack; - - void visitFunctionDeclaration(FunctionDeclaration* N) { - RefGraph.addVertex(N); - Stack.push(N); - visitEachChild(N); - Stack.pop(); - } - - void visitVariableDeclaration(VariableDeclaration* N) { - RefGraph.addVertex(N); - Stack.push(N); - visitEachChild(N); - Stack.pop(); - } - - void visitReferenceExpression(ReferenceExpression* N) { - auto Ref = static_cast(N); - auto Def = Ref->getScope()->lookup(Ref->getSymbolPath()); - if (Def == nullptr) { - // Name lookup failures will be reported directly in inferExpression(). - return; - } - ZEN_ASSERT(isa(Def) || isa(Def) || isa(Def)); - // This case ensures that a deeply nested structure that references a - // parameter of a parent node but is not referenced itself is correctly handled. - // Note that the edge goes from the parent let to the parameter. This is normal. - // if (Def->getKind() == NodeKind::Parameter) { - // RefGraph.addEdge(Stack.top(), Def->Parent); - // return; - // } - if (Stack.empty()) { - // An empty stack means we are traversing the toplevel of the source - // file, in which case we don't have anyting to connect with. - return; - } - RefGraph.addEdge(Def, Stack.top()); - } - - }; - - Visitor V { {}, RefGraph }; - V.visit(SF); - -} - -Type* Checker::getType(TypedNode *Node) { - auto Ty = Node->getType(); - if (Node->Flags & NodeFlags_TypeIsSolved) { - return Ty; - } - Ty = solveType(Ty); - Node->setType(Ty); - Node->Flags |= NodeFlags_TypeIsSolved; - return Ty; -} - -void Checker::check(SourceFile *SF) { - initialize(SF); - setContext(SF->Ctx); - addBinding("String", new Forall(StringType), SymKind::Type); - addBinding("Int", new Forall(IntType), SymKind::Type); - addBinding("Bool", new Forall(BoolType), SymKind::Type); - addBinding("List", new Forall(ListType), SymKind::Type); - addBinding("True", new Forall(BoolType), SymKind::Var); - addBinding("False", new Forall(BoolType), SymKind::Var); - auto A = createTypeVar(); - addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var); - addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - populate(SF); - forwardDeclare(SF); - auto SCCs = RefGraph.strongconnect(); - for (auto Nodes: SCCs) { - auto TVs = new TVSet; - auto Constraints = new ConstraintSet; - for (auto N: Nodes) { - if (!isa(N)) { - continue; - } - auto Decl = static_cast(N); - forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); - } - } - setContext(SF->Ctx); - infer(SF); - - // Important because otherwise some logic for some optimisations will kick in that are no longer active. - ActiveContext = nullptr; - - solve(new CMany(*SF->Ctx->Constraints)); - - class Visitor : public CSTVisitor { - - Checker& C; - - public: - - Visitor(Checker& C): - C(C) {} - - void visitAnnotation(Annotation* A) { - - } - - void visitExpression(Expression* X) { - C.getType(X); - } - - } V(*this); - - V.visit(SF); -} - -void Checker::solve(Constraint* Constraint) { - - Queue.push_back(Constraint); - bool DidJoin = false; - std::deque NextQueue; - - while (true) { - - if (Queue.empty()) { - if (NextQueue.empty() || !DidJoin) { - break; - } - DidJoin = false; - std::swap(Queue, NextQueue); - } - - auto Constraint = Queue.front(); - Queue.pop_front(); - - switch (Constraint->getKind()) { - - case ConstraintKind::Empty: - break; - - case ConstraintKind::Field: - { - auto Field = static_cast(Constraint); - auto MaybeTuple = Field->TupleTy->find(); - if (MaybeTuple->isTuple()) { - auto& Tuple = MaybeTuple->asTuple(); - if (Field->I >= Tuple.ElementTypes.size()) { - DE.add(MaybeTuple, Field->I, Field->Source); - } else { - auto ElementTy = Tuple.ElementTypes[Field->I]; - unify(ElementTy, Field->FieldTy, Field->Source); - } - } else if (MaybeTuple->isVar()) { - NextQueue.push_back(Constraint); - } else { - DE.add(MaybeTuple, Field->Source); - } - break; - } - - case ConstraintKind::Many: - { - auto Many = static_cast(Constraint); - for (auto Constraint: Many->Elements) { - Queue.push_back(Constraint); - } - break; - } - - case ConstraintKind::Equal: - { - auto Equal = static_cast(Constraint); - if (unify(Equal->Left, Equal->Right, Equal->Source)) { - DidJoin = true; - } - break; - } - - } - - } - -} - -bool assignableTo(Type* A, Type* B) { - if (A->isCon() && B->isCon()) { - auto& Con1 = A->asCon(); - auto& Con2 = B->asCon(); - if (Con1.Id != Con2.Id) { - return false; - } - return true; - } - // TODO must handle a TApp - ZEN_UNREACHABLE -} - -class ArrowCursor { - - /// Types on this stack are guaranteed to be arrow types. - std::stack> Stack; - - TypePath& Path; - std::size_t I; - -public: - - ArrowCursor(Type* Arr, TypePath& Path): - Path(Path) { - Stack.push({ Arr, true }); - Path.push_back(Arr->getStartIndex()); - } - - Type* next() { - while (!Stack.empty()) { - auto& [Arrow, First] = Stack.top(); - auto& Index = Path.back(); - if (!First) { - Index.advance(Arrow); - } else { - First = false; - } - Type* Ty; - if (Index == Arrow->getEndIndex()) { - Path.pop_back(); - Stack.pop(); - continue; - } - Ty = Arrow->resolve(Index); - if (Ty->isArrow()) { - auto NewIndex = Arrow->getStartIndex(); - Stack.push({ Ty, true }); - Path.push_back(NewIndex); - } else { - return Ty; - } - } - return nullptr; - } - -}; - -struct Unifier { - - Checker& C; - // CEqual* Constraint; - Type* Left; - Type* Right; - Node* Source; - - // Internal state used by the unifier - ByteString CurrentFieldName; - TypePath LeftPath; - TypePath RightPath; - bool DidJoin = false; - - Type* getLeft() const { - return Left; - } - - Type* getRight() const { - return Right; - } - - Node* getSource() const { - return Source; - } - - bool unifyField(Type* A, Type* B, bool DidSwap); - - bool unify(Type* A, Type* B, bool DidSwap); - - bool unify() { - return unify(Left, Right, false); - } - - std::vector findInstanceContext(const TypeSig& Ty, TypeclassId& Class) { - auto Match = C.InstanceMap.find(Class); - std::vector S; - if (Match != C.InstanceMap.end()) { - for (auto Instance: Match->second) { - if (assignableTo(Ty.Orig, Instance->TypeExps[0]->getType())) { - std::vector S; - for (auto Arg: Ty.Args) { - TypeclassContext Classes; - // TODO - S.push_back(Classes); - } - return S; - } - } - } - C.DE.add(Class, Ty.Orig, getSource()); - for (auto Arg: Ty.Args) { - S.push_back({}); - } - return S; - } - - TypeSig getTypeSig(Type* Ty) { - Type* Op = nullptr; - std::vector Args; - std::function Visit = [&](Type* Ty) { - if (Ty->isApp()) { - Visit(Ty->asApp().Op); - Visit(Ty->asApp().Arg); - } else if (!Op) { - Op = Ty; - } else { - Args.push_back(Ty); - } - }; - Visit(Ty); - return TypeSig { Ty, Op, Args }; - } - - void propagateClasses(std::unordered_set& Classes, Type* Ty) { - if (Ty->isVar()) { - auto TV = Ty->asVar(); - for (auto Class: Classes) { - TV.Context.emplace(Class); - } - if (TV.isRigid()) { - for (auto Id: TV.Context) { - if (!TV.Provided->count(Id)) { - C.DE.add(TypeclassSignature { Id, { Ty } }, getSource()); - } - } - } - } else if (Ty->isCon() || Ty->isApp()) { - auto Sig = getTypeSig(Ty); - for (auto Class: Classes) { - propagateClassTycon(Class, Sig); - } - } else if (!Classes.empty()) { - C.DE.add(Ty, std::vector(Classes.begin(), Classes.end()), getSource()); - } - }; - - void propagateClassTycon(TypeclassId& Class, const TypeSig& Sig) { - auto S = findInstanceContext(Sig, Class); - for (auto [Classes, Arg]: zen::zip(S, Sig.Args)) { - propagateClasses(Classes, Arg); - } - }; - - /** - * Assign a type to a unification variable. - * - * If there are class constraints, those are propagated. - * - * If this type variable is solved during inference, it will be removed from - * the inference context. - * - * Other side effects may occur. - */ - void join(Type* TV, Type* Ty) { - - // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; - - TV->set(Ty); - - DidJoin = true; - - propagateClasses(TV->asVar().Context, Ty); - - // This is a very specific adjustment that is critical to the - // well-functioning of the infer/unify algorithm. When addConstraint() is - // called, it may decide to solve the constraint immediately during - // inference. If this happens, a type variable might get assigned a concrete - // type such as Int. We therefore never want the variable to be polymorphic - // and be instantiated with a fresh variable, as that would allow Bool to - // collide with Int. - // - // Should it get assigned another unification variable, that's OK too - // because then that variable is what matters and it will become the new - // (possibly polymorphic) variable. - if (C.ActiveContext) { - // std::cerr << "erase " << describe(TV) << std::endl; - auto TVs = C.ActiveContext->TVs; - TVs->erase(TV); - } - - } - -}; - -bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { - if (A->isAbsent() && B->isAbsent()) { - return true; - } - if (B->isAbsent()) { - std::swap(A, B); - DidSwap = !DidSwap; - } - if (A->isAbsent()) { - auto& Present = B->asPresent(); - C.DE.add(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource()); - return false; - } - auto& Present1 = A->asPresent(); - auto& Present2 = B->asPresent(); - return unify(Present1.Ty, Present2.Ty, DidSwap); -}; - -bool Unifier::unify(Type* A, Type* B, bool DidSwap) { - +void Checker::unifyTypeType(Type* A, Type* B, Node* N) { A = A->find(); B = B->find(); - - auto unifyError = [&]() { - C.DE.add( - Left, - Right, - LeftPath, - RightPath, - Source - ); - }; - - auto pushLeft = [&](TypeIndex I) { - if (DidSwap) { - RightPath.push_back(I); - } else { - LeftPath.push_back(I); - } - }; - - auto popLeft = [&]() { - if (DidSwap) { - RightPath.pop_back(); - } else { - LeftPath.pop_back(); - } - }; - - auto pushRight = [&](TypeIndex I) { - if (DidSwap) { - LeftPath.push_back(I); - } else { - RightPath.push_back(I); - } - }; - - auto popRight = [&]() { - if (DidSwap) { - LeftPath.pop_back(); - } else { - RightPath.pop_back(); - } - }; - - auto swap = [&]() { - std::swap(A, B); - DidSwap = !DidSwap; - }; - - if (A->isVar() && B->isVar()) { - auto& Var1 = A->asVar(); - auto& Var2 = B->asVar(); - if (Var1.isRigid() && Var2.isRigid()) { - if (Var1.Id != Var2.Id) { - unifyError(); - return false; - } - return true; - } - Type* To; - Type* From; - if (Var1.isRigid() && Var2.isUni()) { - To = A; - From = B; - } else { - // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification - // Either way, Var1, being Unification, is a good candidate for being unified away - To = B; - From = A; - } - if (From->asVar().Id != To->asVar().Id) { - join(From, To); - } - return true; + if (A->getKind() == TypeKind::Var) { + auto TV = static_cast(A); + // TODO occurs check + TV->set(B); + return; } - - if (B->isVar()) { - swap(); + if (B->getKind() == TypeKind::Var) { + unifyTypeType(B, A, N); + return; } - - if (A->isVar()) { - - auto& TV = A->asVar(); - - // Rigid type variables can never unify with antything else than what we - // have already handled in the previous if-statement, so issue an error. - if (TV.isRigid()) { - unifyError(); - return false; + if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) { + auto C1 = static_cast(A); + auto C2 = static_cast(B); + if (C1->getName() == C2->getName()) { + return; } - - // Occurs check - if (B->hasTypeVar(A)) { - // NOTE Just like GHC, we just display an error message indicating that - // A cannot match B, e.g. a cannot match [a]. It looks much better - // than obsure references to an occurs check - unifyError(); - return false; - } - - join(A, B); - - return true; } - - if (A->isArrow() && B->isArrow()) { - auto& Arrow1 = A->asArrow(); - auto& Arrow2 = B->asArrow(); - bool Success = true; - LeftPath.push_back(TypeIndex::forArrowParamType()); - RightPath.push_back(TypeIndex::forArrowParamType()); - if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - LeftPath.push_back(TypeIndex::forArrowReturnType()); - RightPath.push_back(TypeIndex::forArrowReturnType()); - if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - return Success; + if (A->getKind() == TypeKind::Fun && B->getKind() == TypeKind::Fun) { + auto F1 = static_cast(A); + auto F2 = static_cast(B); + unifyTypeType(F1->getLeft(), F2->getLeft(), N); + unifyTypeType(F1->getRight(), F2->getRight(), N); + return; } - - if (A->isApp() && B->isApp()) { - auto& App1 = A->asApp(); - auto& App2 = B->asApp(); - bool Success = true; - LeftPath.push_back(TypeIndex::forAppOpType()); - RightPath.push_back(TypeIndex::forAppOpType()); - if (!unify(App1.Op, App2.Op, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - LeftPath.push_back(TypeIndex::forAppArgType()); - RightPath.push_back(TypeIndex::forAppArgType()); - if (!unify(App1.Arg, App2.Arg, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - return Success; - } - - if (A->isTuple() && B->isTuple()) { - auto& Tuple1 = A->asTuple(); - auto& Tuple2 = B->asTuple(); - if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) { - unifyError(); - return false; - } - auto Count = Tuple1.ElementTypes.size(); - bool Success = true; - for (size_t I = 0; I < Count; I++) { - LeftPath.push_back(TypeIndex::forTupleElement(I)); - RightPath.push_back(TypeIndex::forTupleElement(I)); - if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - } - return Success; - } - - // if (A->isTupleIndex() || B->isTupleIndex()) { - // // Type(s) could not be simplified at the beginning of this function, - // // so we have to re-visit the constraint when there is more information. - // C.Queue.push_back(Constraint); - // return true; - // } - - // This does not work because it ignores the indices - // if (A->isTupleIndex() && B->isTupleIndex()) { - // auto Index1 = static_cast(A); - // auto Index2 = static_cast(B); - // return unify(Index1->Ty, Index2->Ty, Source); - // } - - if (A->isCon() && B->isCon()) { - auto& Con1 = A->asCon(); - auto& Con2 = B->asCon(); - if (Con1.Id != Con2.Id) { - unifyError(); - return false; - } - return true; - } - - if (A->isNil() && B->isNil()) { - return true; - } - - if (A->isField() && B->isField()) { - auto& Field1 = A->asField(); - auto& Field2 = B->asField(); - bool Success = true; - if (Field1.Name == Field2.Name) { - LeftPath.push_back(TypeIndex::forFieldType()); - RightPath.push_back(TypeIndex::forFieldType()); - CurrentFieldName = Field1.Name; - if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - LeftPath.push_back(TypeIndex::forFieldRest()); - RightPath.push_back(TypeIndex::forFieldRest()); - if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - return Success; - } - auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++)); - pushLeft(TypeIndex::forFieldRest()); - if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) { - Success = false; - } - popLeft(); - pushRight(TypeIndex::forFieldRest()); - if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) { - Success = false; - } - popRight(); - return Success; - } - - if (A->isNil() && B->isField()) { - swap(); - } - - if (A->isField() && B->isNil()) { - auto& Field = A->asField(); - bool Success = true; - pushLeft(TypeIndex::forFieldType()); - CurrentFieldName = Field.Name; - if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) { - Success = false; - } - popLeft(); - pushLeft(TypeIndex::forFieldRest()); - if (!unify(Field.RestTy, B, DidSwap)) { - Success = false; - } - popLeft(); - return Success; - } - - unifyError(); - return false; + DE.add(A, B, N); } -bool Checker::unify(Type* Left, Type* Right, Node* Source) { - // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; - Unifier A { *this, Left, Right, Source }; - A.unify(); - return A.DidJoin; +void Checker::run(SourceFile* SF) { + TypeEnv Env; + Env.add("Int", getIntType(), SymbolKind::Type); + Env.add("Bool", getBoolType(), SymbolKind::Type); + Env.add("String", getStringType(), SymbolKind::Type); + Env.add("True", getBoolType(), SymbolKind::Var); + Env.add("False", getBoolType(), SymbolKind::Var); + Env.add("+", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var); + Env.add("-", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var); + auto Out = inferSourceFile(Env, SF); + solve(Out); +} + +Type* resolveType(Type* Ty) { + switch (Ty->getKind()) { + case TypeKind::App: + { + auto A = static_cast(Ty); + auto NewLeft = resolveType(A->getLeft()); + auto NewRight = resolveType(A->getRight()); + if (A->getLeft() == NewLeft && A->getRight() == NewRight) { + return Ty; + } + return new TApp(NewLeft, NewRight); + } + case TypeKind::Con: + return Ty; + case TypeKind::Var: + { + auto NewTy = Ty->find(); + if (NewTy->getKind() != TypeKind::Var) { + return resolveType(NewTy); + } else { + return NewTy; + } + } + case TypeKind::Fun: + { + auto F = static_cast(Ty); + auto NewLeft = resolveType(F->getLeft()); + auto NewRight = resolveType(F->getRight()); + if (F->getLeft() == NewLeft && F->getRight() == NewRight) { + return Ty; + } + return new TFun(NewLeft, NewRight); + } + } +} + +Type* Checker::getTypeOfNode(Node* N) { + auto M = cast(N); + return resolveType(M->getType()); } } - diff --git a/src/ConsolePrinter.cc b/src/ConsolePrinter.cc index 5da51c905..3b5ae378a 100644 --- a/src/ConsolePrinter.cc +++ b/src/ConsolePrinter.cc @@ -1,11 +1,9 @@ // FIXME writeExcerpt does not work well with the last line in a file -#include +#include #include -#include "zen/config.hpp" - #include "bolt/CST.hpp" #include "bolt/Type.hpp" #include "bolt/Diagnostics.hpp" @@ -182,6 +180,8 @@ static std::string describe(NodeKind Type) { return "a variant"; case NodeKind::MatchCase: return "a match-arm"; + case NodeKind::LetExprBody: + return "the body of a let-declaration"; default: ZEN_UNREACHABLE } @@ -199,79 +199,6 @@ static std::string describe(Token* T) { } } -std::string describe(const Type* Ty) { - Ty = Ty->find(); - switch (Ty->getKind()) { - case TypeKind::Var: - { - auto TV = Ty->asVar(); - if (TV.isRigid()) { - return *TV.Name; - } - return "a" + std::to_string(TV.Id); - } - case TypeKind::Arrow: - { - auto Y = Ty->asArrow(); - std::ostringstream Out; - Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType); - return Out.str(); - } - case TypeKind::Con: - { - auto Y = Ty->asCon(); - return Y.DisplayName; - } - case TypeKind::App: - { - auto Y = Ty->asApp(); - return describe(Y.Op) + " " + describe(Y.Arg); - } - case TypeKind::Tuple: - { - std::ostringstream Out; - auto Y = Ty->asTuple(); - Out << "("; - if (Y.ElementTypes.size()) { - auto Iter = Y.ElementTypes.begin(); - Out << describe(*Iter++); - while (Iter != Y.ElementTypes.end()) { - Out << ", " << describe(*Iter++); - } - } - Out << ")"; - return Out.str(); - } - case TypeKind::Nil: - return "{}"; - case TypeKind::Absent: - return "Abs"; - case TypeKind::Present: - { - auto Y = Ty->asPresent(); - return describe(Y.Ty); - } - case TypeKind::Field: - { - auto Y = Ty->asField(); - std::ostringstream out; - out << "{ " << Y.Name << ": " << describe(Y.Ty); - Ty = Y.RestTy; - while (Ty->getKind() == TypeKind::Field) { - auto Y = Ty->asField(); - out << "; " + Y.Name + ": " + describe(Y.Ty); - Ty = Y.RestTy; - } - if (Ty->getKind() != TypeKind::Nil) { - out << "; " + describe(Ty); - } - out << " }"; - return out.str(); - } - } - ZEN_UNREACHABLE -} - void writeForegroundANSI(Color C, std::ostream& Out) { switch (C) { case Color::None: @@ -533,153 +460,6 @@ void ConsolePrinter::writeBinding(const ByteString& Name) { write("'"); } -void ConsolePrinter::writeType(const Type* Ty) { - TypePath Path; - writeType(Ty, Path); -} - -void ConsolePrinter::writeType(const Type* Ty, const TypePath& Underline) { - - setForegroundColor(Color::Green); - - class TypePrinter : public ConstTypeVisitor { - - TypePath Path; - ConsolePrinter& W; - const TypePath& Underline; - - public: - - TypePrinter(ConsolePrinter& W, const TypePath& Underline): - W(W), Underline(Underline) {} - - bool shouldUnderline() const { - return !Underline.empty() && Path == Underline; - } - - void enterType(const Type* Ty) override { - if (shouldUnderline()) { - W.setUnderline(true); - } - } - - void exitType(const Type* Ty) override { - if (shouldUnderline()) { - W.setUnderline(false); // FIXME Should set to old value - } - } - - void visitAppType(const TApp& Ty) override { - Path.push_back(TypeIndex::forAppOpType()); - visit(Ty.Op); - Path.pop_back(); - W.write(" "); - Path.push_back(TypeIndex::forAppArgType()); - visit(Ty.Arg); - Path.pop_back(); - } - - void visitVarType(const TVar& Ty) override { - if (Ty.isRigid()) { - W.write(*Ty.Name); - return; - } - W.write("a"); - W.write(Ty.Id); - } - - void visitConType(const TCon& Ty) override { - W.write(Ty.DisplayName); - } - - void visitArrowType(const TArrow& Ty) override { - Path.push_back(TypeIndex::forArrowParamType()); - visit(Ty.ParamType); - Path.pop_back(); - W.write(" -> "); - Path.push_back(TypeIndex::forArrowReturnType()); - visit(Ty.ReturnType); - Path.pop_back(); - } - - void visitTupleType(const TTuple& Ty) override { - W.write("("); - if (Ty.ElementTypes.size()) { - auto Iter = Ty.ElementTypes.begin(); - Path.push_back(TypeIndex::forTupleElement(0)); - visit(*Iter++); - Path.pop_back(); - std::size_t I = 1; - while (Iter != Ty.ElementTypes.end()) { - W.write(", "); - Path.push_back(TypeIndex::forTupleElement(I++)); - visit(*Iter++); - Path.pop_back(); - } - } - W.write(")"); - } - - void visitNilType(const TNil& Ty) override { - W.write("{}"); - } - - void visitAbsentType(const TAbsent& Ty) override { - W.write("Abs"); - } - - void visitPresentType(const TPresent& Ty) override { - Path.push_back(TypeIndex::forPresentType()); - visit(Ty.Ty); - Path.pop_back(); - } - - void visitFieldType(const TField& Ty) override { - W.write("{ "); - W.write(Ty.Name); - W.write(": "); - Path.push_back(TypeIndex::forFieldType()); - visit(Ty.Ty); - Path.pop_back(); - auto Ty2 = Ty.RestTy; - Path.push_back(TypeIndex::forFieldRest()); - std::size_t I = 1; - while (Ty2->isField()) { - auto Y = Ty2->asField(); - W.write("; "); - W.write(Y.Name); - W.write(": "); - Path.push_back(TypeIndex::forFieldType()); - visit(Y.Ty); - Path.pop_back(); - Ty2 = Y.RestTy; - Path.push_back(TypeIndex::forFieldRest()); - ++I; - } - if (Ty2->getKind() != TypeKind::Nil) { - W.write("; "); - visit(Ty2); - } - W.write(" }"); - for (auto K = 0; K < I; K++) { - Path.pop_back(); - } - } - - }; - - TypePrinter P { *this, Underline }; - P.visit(Ty); - - resetStyles(); -} - -void ConsolePrinter::writeType(std::size_t I) { - setForegroundColor(Color::Green); - write(I); - resetStyles(); -} - void ConsolePrinter::writeNode(const Node* N) { auto Range = N->getRange(); writeExcerpt(N->getSourceFile()->getTextFile(), Range, Range, Color::Red); @@ -703,19 +483,42 @@ void ConsolePrinter::writePrefix(const Diagnostic& D) { resetStyles(); } -void ConsolePrinter::writeTypeclassName(const ByteString& Name) { - setForegroundColor(Color::Magenta); - write(Name); - resetStyles(); -} - -void ConsolePrinter::writeTypeclassSignature(const TypeclassSignature& Sig) { - setForegroundColor(Color::Magenta); - write(Sig.Id); - for (auto TV: Sig.Params) { - write(" "); - write(describe(TV)); - } +void ConsolePrinter::writeType(Type* Ty) { + std::function visit = [&](auto Ty) { + switch (Ty->getKind()) { + case TypeKind::Var: + { + auto T = static_cast(Ty); + // FIXME + write("α"); + break; + } + case TypeKind::Con: + { + auto T = static_cast(Ty); + write(T->getName()); + break; + } + case TypeKind::Fun: + { + auto T = static_cast(Ty); + visit(T->getLeft()); + write(" -> "); + visit(T->getRight()); + break; + } + case TypeKind::App: + { + auto T = static_cast(Ty); + visit(T->getLeft()); + write(" "); + visit(T->getRight()); + break; + } + } + }; + setForegroundColor(Color::Green); + visit(Ty); resetStyles(); } @@ -799,11 +602,14 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) { return; } - case DiagnosticKind::UnificationError: + case DiagnosticKind::TypeMismatchError: { - auto& E = static_cast(D); - auto Left = E.OrigLeft->resolve(E.LeftPath); - auto Right = E.OrigRight->resolve(E.RightPath); + auto& E = static_cast(D); + // auto Left = E.OrigLeft->resolve(E.LeftPath); + // auto Right = E.OrigRight->resolve(E.RightPath); + auto Left = E.Left; + auto Right = E.Right; + auto S = E.getNode(); writePrefix(E); write("the types "); writeType(Left); @@ -815,7 +621,7 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) { write(" info: "); resetStyles(); write("due to an equality constraint on "); - write(describe(E.Source->getKind())); + write(describe(S->getKind())); write(":\n\n"); // write(" - left type "); // writeType(E.OrigLeft, E.LeftPath); @@ -823,7 +629,7 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) { // write(" - right type "); // writeType(E.OrigRight, E.RightPath); // write("\n\n"); - writeNode(E.Source); + writeNode(S); write("\n"); // if (E.Left != E.OrigLeft) { // setForegroundColor(Color::Yellow); @@ -850,87 +656,6 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) { return; } - case DiagnosticKind::TypeclassMissing: - { - auto& E = static_cast(D); - writePrefix(E); - write("the type class "); - writeTypeclassSignature(E.Sig); - write(" is missing from the declaration's type signature\n\n"); - writeNode(E.Decl); - write("\n\n"); - return; - } - - case DiagnosticKind::InstanceNotFound: - { - auto& E = static_cast(D); - writePrefix(E); - write("a type class instance "); - writeTypeclassName(E.TypeclassName); - write(" "); - writeType(E.Ty); - write(" was not found.\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - case DiagnosticKind::TupleIndexOutOfRange: - { - auto& E = static_cast(D); - writePrefix(E); - write("the index "); - writeType(E.I); - write(" is out of range for tuple "); - writeType(E.Tuple); - write("\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - case DiagnosticKind::InvalidTypeToTypeclass: - { - auto& E = static_cast(D); - writePrefix(E); - write("the type "); - writeType(E.Actual); - write(" was applied to type class names "); - bool First = true; - for (auto Class: E.Classes) { - if (First) First = false; - else write(", "); - writeTypeclassName(Class); - } - write(" but this is invalid\n\n"); - return; - } - - case DiagnosticKind::FieldNotFound: - { - auto& E = static_cast(D); - writePrefix(E); - write("the field '"); - write(E.Name); - write("' was required in one type but not found in another\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - case DiagnosticKind::NotATuple: - { - auto& E = static_cast(D); - writePrefix(E); - write("the type "); - writeType(E.Ty); - write(" is not a tuple.\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - } ZEN_UNREACHABLE diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index e0ec04e0b..5b93ee339 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -2,10 +2,9 @@ // FIXME writeExcerpt does not work well with the last line in a file #include +#include #include -#include "zen/config.hpp" - #include "bolt/CST.hpp" #include "bolt/Type.hpp" #include "bolt/DiagnosticEngine.hpp" diff --git a/src/Evaluator.cc b/src/Evaluator.cc index d8914db69..04a18c85c 100644 --- a/src/Evaluator.cc +++ b/src/Evaluator.cc @@ -1,6 +1,4 @@ -#include "zen/range.hpp" - #include "bolt/CST.hpp" #include "bolt/Evaluator.hpp" @@ -62,8 +60,24 @@ Value Evaluator::apply(Value Op, std::vector Args) { { auto Fn = Op.getDeclaration(); Env NewEnv; - for (auto [Param, Arg]: zen::zip(Fn->getParams(), Args)) { - assignPattern(Param->Pattern, Arg, NewEnv); + auto Params= Fn->getParams(); + auto ParamIter = Params.begin(); + auto ParamsEnd = Params.end(); + auto ArgIter = Args.begin(); + auto ArgsEnd= Args.end(); + for (;;) { + if (ParamIter == ParamsEnd && ArgIter == ArgsEnd) { + break; + } + if (ParamIter == ParamsEnd) { + // TODO Make this a soft failure + ZEN_PANIC("Too much arguments supplied to function call."); + } + if (ArgIter == ArgsEnd) { + // TODO Make this a soft failure + ZEN_PANIC("Too much few arguments supplied to function call."); + } + assignPattern((*ParamIter)->Pattern, *ArgIter, NewEnv); } switch (Fn->getBody()->getKind()) { case NodeKind::LetExprBody: diff --git a/src/LLVMCodeGen.cc b/src/LLVMCodeGen.cc index 26cde5ef4..56955e871 100644 --- a/src/LLVMCodeGen.cc +++ b/src/LLVMCodeGen.cc @@ -2,12 +2,11 @@ #include #include -#include "llvm/IR/Value.h" - -#include "LLVMCodeGen.hpp" #include "bolt/CST.hpp" #include "bolt/CSTVisitor.hpp" +#include "LLVMCodeGen.hpp" + namespace bolt { LLVMCodeGen::LLVMCodeGen(llvm::LLVMContext* TheContext): diff --git a/src/LLVMCodeGen.hpp b/src/LLVMCodeGen.hpp index b5b86f45f..e66501ea9 100644 --- a/src/LLVMCodeGen.hpp +++ b/src/LLVMCodeGen.hpp @@ -1,7 +1,8 @@ #pragma once -#include "llvm/IR/Value.h" +#include + #include "llvm/IR/IRBuilder.h" namespace bolt { diff --git a/src/Parser.cc b/src/Parser.cc index 8de8939df..48a748020 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -4,6 +4,8 @@ #include #include +#include "zen/config.hpp" + #include "bolt/Common.hpp" #include "bolt/CST.hpp" #include "bolt/Scanner.hpp" diff --git a/src/Program.cc b/src/Program.cc new file mode 100644 index 000000000..573516eb1 --- /dev/null +++ b/src/Program.cc @@ -0,0 +1,3 @@ + +#include "bolt/Program.hpp" + diff --git a/src/Scanner.cc b/src/Scanner.cc index 844f57ef2..d6c790fda 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -1,8 +1,6 @@ #include -#include "zen/config.hpp" - #include "bolt/Common.hpp" #include "bolt/Text.hpp" #include "bolt/Integer.hpp" diff --git a/src/Type.cc b/src/Type.cc new file mode 100644 index 000000000..ab3ad7827 --- /dev/null +++ b/src/Type.cc @@ -0,0 +1,102 @@ + +#include "zen/config.hpp" + +#include "bolt/Type.hpp" + +namespace bolt { + +Type* Type::resolve(const TypePath& P) { + auto Ty = this; + for (auto& Index: P) { + switch (Index.Kind) { + case TypeIndexKind::AppOp: + Ty = static_cast(Ty)->getLeft(); + break; + case TypeIndexKind::AppArg: + Ty = static_cast(Ty)->getRight(); + break; + case TypeIndexKind::ArrowLeft: + Ty = static_cast(Ty)->getLeft(); + break; + case TypeIndexKind::ArrowRight: + Ty = static_cast(Ty)->getRight(); + break; + default: + ZEN_UNREACHABLE + } + } + return Ty; +} + +bool Type::operator==(const Type& Other) const { + if (Other.getKind() != TK) { + return false; + } + switch (TK) { + case TypeKind::App: + { + auto A1 = static_cast(*this); + auto A2 = static_cast(Other); + return *A1.getLeft() == *A2.getLeft() && *A1.getRight() == *A2.getRight(); + } + case TypeKind::Var: + return this == &Other; + case TypeKind::Fun: + { + auto F1 = static_cast(*this); + auto F2 = static_cast(Other); + return *F1.getLeft() == *F2.getLeft() && *F1.getRight() == *F2.getRight(); + } + case TypeKind::Con: + { + auto C1 = static_cast(*this); + auto C2 = static_cast(Other); + return C1.getName() == C2.getName(); + } + } +} + +std::string Type::toString() const { + switch (TK) { + case TypeKind::App: + { + auto A = static_cast(this); + return A->getLeft()->toString() + " " + A->getRight()->toString(); + } + case TypeKind::Con: + { + auto C = static_cast(this); + return std::string(C->getName()); + } + case TypeKind::Fun: + { + auto F = static_cast(this); + return F->getLeft()->toString() + " -> " + F->getRight()->toString(); + } + case TypeKind::Var: + return "α"; + } +} + +TVar* Type::asVar() { + return static_cast(this); +} + +void TypeVisitor::visit(Type* Ty) { + switch (Ty->getKind()) { + case TypeKind::App: + visitApp(static_cast(Ty)); + break; + case TypeKind::Con: + visitCon(static_cast(Ty)); + break; + case TypeKind::Fun: + visitFun(static_cast(Ty)); + break; + case TypeKind::Var: + visitVar(static_cast(Ty)); + break; + } +} + +} diff --git a/src/Types.cc b/src/Types.cc deleted file mode 100644 index 2ff1f318f..000000000 --- a/src/Types.cc +++ /dev/null @@ -1,336 +0,0 @@ - -#include "bolt/Type.hpp" -#include -#include -#include - -#include "zen/range.hpp" - -namespace bolt { - -bool TypeclassSignature::operator<(const TypeclassSignature& Other) const { - if (Id < Other.Id) { - return true; - } - ZEN_ASSERT(Params.size() == 1); - ZEN_ASSERT(Other.Params.size() == 1); - return Params[0]->asCon().Id < Other.Params[0]->asCon().Id; -} - -bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { - ZEN_ASSERT(Params.size() == 1); - ZEN_ASSERT(Other.Params.size() == 1); - return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id; -} - -bool TypeIndex::operator==(const TypeIndex& Other) const noexcept { - if (Kind != Other.Kind) { - return false; - } - switch (Kind) { - case TypeIndexKind::ArrowParamType: - case TypeIndexKind::TupleElement: - return I == Other.I; - default: - return true; - } -} - -bool TCon::operator==(const TCon& Other) const { - return Id == Other.Id; -} - -bool TApp::operator==(const TApp& Other) const { - return *Op == *Other.Op && *Arg == *Other.Arg; -} - -bool TVar::operator==(const TVar& Other) const { - return Id == Other.Id; -} - -bool TArrow::operator==(const TArrow& Other) const { - return *ParamType == *Other.ParamType - && *ReturnType == *Other.ReturnType; -} - -bool TTuple::operator==(const TTuple& Other) const { - for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) { - if (*T1 != *T2) { - return false; - } - } - return true; -} - -bool TNil::operator==(const TNil& Other) const { - return true; -} - -bool TField::operator==(const TField& Other) const { - return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy; -} - -bool TAbsent::operator==(const TAbsent& Other) const { - return true; -} - -bool TPresent::operator==(const TPresent& Other) const { - return *Ty == *Other.Ty; -} - -bool Type::operator==(const Type& Other) const { - if (Kind != Other.Kind) { - return false; - } - switch (Kind) { - case TypeKind::Var: - return Var == Other.Var; - case TypeKind::Con: - return Con == Other.Con; - case TypeKind::Present: - return Present == Other.Present; - case TypeKind::Absent: - return Absent == Other.Absent; - case TypeKind::Arrow: - return Arrow == Other.Arrow; - case TypeKind::Field: - return Field == Other.Field; - case TypeKind::Nil: - return Nil == Other.Nil; - case TypeKind::Tuple: - return Tuple == Other.Tuple; - case TypeKind::App: - return App == Other.App; - } - ZEN_UNREACHABLE -} - -void Type::visitEachChild(std::function Proc) { - switch (Kind) { - case TypeKind::Var: - case TypeKind::Absent: - case TypeKind::Nil: - case TypeKind::Con: - break; - case TypeKind::Arrow: - { - Proc(Arrow.ParamType); - Proc(Arrow.ReturnType); - break; - } - case TypeKind::Tuple: - { - for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) { - Proc(Tuple.ElementTypes[I]); - } - break; - } - case TypeKind::App: - { - Proc(App.Op); - Proc(App.Arg); - break; - } - case TypeKind::Field: - { - Proc(Field.Ty); - Proc(Field.RestTy); - break; - } - case TypeKind::Present: - { - Proc(Present.Ty); - break; - } - } -} - -Type* Type::rewrite(std::function Fn, bool Recursive) { - auto Ty2 = Fn(this); - if (this != Ty2) { - if (Recursive) { - return Ty2->rewrite(Fn, Recursive); - } - return Ty2; - } - switch (Kind) { - case TypeKind::Var: - return Ty2; - case TypeKind::Arrow: - { - auto Arrow = Ty2->asArrow(); - bool Changed = false; - Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive); - if (NewParamType != Arrow.ParamType) { - Changed = true; - } - auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive); - if (NewRetTy != Arrow.ReturnType) { - Changed = true; - } - return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2; - } - case TypeKind::Con: - return Ty2; - case TypeKind::App: - { - auto App = Ty2->asApp(); - auto NewOp = App.Op->rewrite(Fn, Recursive); - auto NewArg = App.Arg->rewrite(Fn, Recursive); - if (NewOp == App.Op && NewArg == App.Arg) { - return Ty2; - } - return new Type(TApp(NewOp, NewArg)); - } - case TypeKind::Tuple: - { - auto Tuple = Ty2->asTuple(); - bool Changed = false; - std::vector NewElementTypes; - for (auto Ty: Tuple.ElementTypes) { - auto NewElementType = Ty->rewrite(Fn, Recursive); - if (NewElementType != Ty) { - Changed = true; - } - NewElementTypes.push_back(NewElementType); - } - return Changed ? new Type(TTuple(NewElementTypes)) : Ty2; - } - case TypeKind::Nil: - return Ty2; - case TypeKind::Absent: - return Ty2; - case TypeKind::Field: - { - auto Field = Ty2->asField(); - bool Changed = false; - auto NewTy = Field.Ty->rewrite(Fn, Recursive); - if (NewTy != Field.Ty) { - Changed = true; - } - auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive); - if (NewRestTy != Field.RestTy) { - Changed = true; - } - return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2; - } - case TypeKind::Present: - { - auto Present = Ty2->asPresent(); - auto NewTy = Present.Ty->rewrite(Fn, Recursive); - if (NewTy == Present.Ty) { - return Ty2; - } - return new Type(TPresent(NewTy)); - } - } - ZEN_UNREACHABLE -} - -Type* Type::substitute(const TVSub &Sub) { - return rewrite([&](auto Ty) { - if (Ty->isVar()) { - auto Match = Sub.find(Ty); - return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; - } - return Ty; - }, false); -} - -Type* Type::resolve(const TypeIndex& Index) const noexcept { - switch (Index.Kind) { - case TypeIndexKind::PresentType: - return this->asPresent().Ty; - case TypeIndexKind::AppOpType: - return this->asApp().Op; - case TypeIndexKind::AppArgType: - return this->asApp().Arg; - case TypeIndexKind::TupleElement: - return this->asTuple().ElementTypes[Index.I]; - case TypeIndexKind::ArrowParamType: - return this->asArrow().ParamType; - case TypeIndexKind::ArrowReturnType: - return this->asArrow().ReturnType; - case TypeIndexKind::FieldType: - return this->asField().Ty; - case TypeIndexKind::FieldRestType: - return this->asField().RestTy; - case TypeIndexKind::End: - ZEN_UNREACHABLE - } - ZEN_UNREACHABLE -} - -TVSet Type::getTypeVars() { - TVSet Out; - std::function visit = [&](Type* Ty) { - if (Ty->isVar()) { - Out.emplace(Ty); - return; - } - Ty->visitEachChild(visit); - }; - visit(this); - return Out; -} - -TypeIterator Type::begin() { - return TypeIterator { this, getStartIndex() }; -} - -TypeIterator Type::end() { - return TypeIterator { this, getEndIndex() }; -} - -TypeIndex Type::getStartIndex() const { - switch (Kind) { - case TypeKind::Arrow: - return TypeIndex::forArrowParamType(); - case TypeKind::Tuple: - { - if (asTuple().ElementTypes.empty()) { - return TypeIndex(TypeIndexKind::End); - } - return TypeIndex::forTupleElement(0); - } - case TypeKind::Field: - return TypeIndex::forFieldType(); - default: - return TypeIndex(TypeIndexKind::End); - } -} - -TypeIndex Type::getEndIndex() const { - return TypeIndex(TypeIndexKind::End); -} - -bool Type::hasTypeVar(Type* TV) const { - switch (Kind) { - case TypeKind::Var: - return Var.Id == TV->asVar().Id; - case TypeKind::Con: - case TypeKind::Absent: - case TypeKind::Nil: - return false; - case TypeKind::App: - return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV); - case TypeKind::Tuple: - for (auto Ty: Tuple.ElementTypes) { - if (Ty->hasTypeVar(TV)) { - return true; - } - } - return false; - case TypeKind::Field: - return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV); - case TypeKind::Arrow: - return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV); - case TypeKind::Present: - return Present.Ty->hasTypeVar(TV); - } - ZEN_UNREACHABLE -} - -} - - diff --git a/src/main.cc b/src/main.cc index 876dedf72..2020a1cfe 100644 --- a/src/main.cc +++ b/src/main.cc @@ -6,7 +6,6 @@ #include #include -#include "zen/config.hpp" #include "zen/po.hpp" #include "bolt/CST.hpp" @@ -113,11 +112,11 @@ int main(int Argc, const char* Argv[]) { void visitExpression(Expression* N) { for (auto A: N->Annotations) { if (A->getKind() == NodeKind::TypeAssertAnnotation) { - auto Left = C.getType(N); + auto Left = C.getTypeOfNode(N); auto Right = static_cast(A)->getTypeExpression()->getType(); - std::cerr << "verify " << describe(Left) << " == " << describe(Right) << std::endl; + std::cerr << "verify " << Left->toString() << " == " << Right->toString() << std::endl; if (*Left != *Right) { - DE.add(Left, Right, TypePath(), TypePath(), A); + DE.add(Left, Right, A); } } } diff --git a/x.py b/x.py index 4fd8a9b92..52e467f86 100755 --- a/x.py +++ b/x.py @@ -167,7 +167,8 @@ def build_bolt(c_path: str | None = None, cxx_path: str | None = None) -> None: 'CMAKE_BUILD_TYPE': 'Debug', 'BOLT_ENABLE_TESTS': True, 'ZEN_ENABLE_TESTS': False, - 'LLVM_CONFIG': str(llvm_config_path) + #'LLVM_CONFIG': str(llvm_config_path), + 'LLVM_TARGETS_TO_BUILD': 'X86', } if c_path is not None: defines['CMAKE_C_COMPILER'] = c_path @@ -196,13 +197,13 @@ c_path = None cxx_path = None if os.name == 'posix': - clang_c_path = shutil.which('clangj') + clang_c_path = shutil.which('clang') clang_cxx_path = shutil.which('clang++') if clang_c_path is not None and clang_cxx_path is not None and (force == NONE or force == CLANG): c_path = clang_c_path cxx_path = clang_cxx_path else: - for version in [ '18' ]: + for version in [ '18', '19' ]: clang_c_path = shutil.which(f'clang-{version}') clang_cxx_path = shutil.which(f'clang++-{version}') if clang_c_path is not None and clang_cxx_path is not None and (force == NONE or force == CLANG):