From b1d685bdaf27d8d5434b0b96335cf6b032425abf Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sun, 2 Mar 2025 00:19:19 +0100 Subject: [PATCH] Add again generalization in Checker.cc Constraint instantiation is still missing. --- CMakeLists.txt | 1 + include/bolt/Checker.hpp | 6 +++- include/bolt/Constraint.hpp | 6 ++++ include/bolt/Type.hpp | 19 ++++++---- src/Checker.cc | 71 ++++++++++++++++++++++++++++--------- src/Constraint.cc | 17 +++++++++ src/Type.cc | 16 ++++++++- 7 files changed, 111 insertions(+), 25 deletions(-) create mode 100644 src/Constraint.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 90f6b0761..fe0da56ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ add_library( src/Scanner.cc src/Parser.cc src/Type.cc + src/Constraint.cc src/Checker.cc src/Evaluator.cc src/Scope.cc diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 2a5187a82..340347376 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -31,6 +31,8 @@ public: bool hasVar(TVar* TV) const; + void dump() const; + TypeScheme* lookup(ByteString Name, SymbolKind Kind); }; @@ -46,6 +48,8 @@ class Checker { Type* StringType; Type* UnitType; + unsigned NextVarId = 0; + public: Checker(DiagnosticEngine& DE); @@ -67,7 +71,7 @@ public: } TVar* createTVar() { - return new TVar(); + return new TVar("a" + std::to_string(NextVarId++)); } Type* instantiate(TypeScheme* Scm); diff --git a/include/bolt/Constraint.hpp b/include/bolt/Constraint.hpp index 848f23a01..9bf98e7e2 100644 --- a/include/bolt/Constraint.hpp +++ b/include/bolt/Constraint.hpp @@ -5,6 +5,8 @@ namespace bolt { +class Node; + enum class ConstraintKind { TypesEqual, }; @@ -24,6 +26,8 @@ public: return Kind; } + std::string toString() const; + }; class CTypesEqual : public Constraint { @@ -49,6 +53,8 @@ public: return Origin; } + std::string toString() const; + }; } diff --git a/include/bolt/Type.hpp b/include/bolt/Type.hpp index dd28bd034..e50fe863b 100644 --- a/include/bolt/Type.hpp +++ b/include/bolt/Type.hpp @@ -12,6 +12,8 @@ namespace bolt { +class Constraint; + enum class TypeIndexKind { AppOp, AppArg, @@ -123,8 +125,10 @@ class TVar : public Type { public: - TVar(): - Type(TypeKind::Var) {} + std::string Name; + + TVar(std::string Name): + Type(TypeKind::Var), Name(Name) {} void set(Type* Ty) { auto Root = find(); @@ -136,11 +140,11 @@ public: Type* find() const override { TVar* Curr = const_cast(this); for (;;) { - auto Keep = Curr->Parent; - if (Keep == Curr || !Keep->isVar()) { - return Keep; + auto Parent = Curr->Parent; + if (Parent == Curr || !Parent->isVar()) { + return Parent; } - auto Keep2 = static_cast(Keep); + auto Keep2 = static_cast(Parent); Curr->Parent = Keep2->Parent; Curr = Keep2; } @@ -214,12 +218,15 @@ public: struct TypeScheme { std::unordered_set Unbound; + std::vector Constraints; Type* Ty; Type* getType() const { return Ty; } + std::string toString() const; + }; class TypeVisitor { diff --git a/src/Checker.cc b/src/Checker.cc index a3f2aead0..885f62d16 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -1,15 +1,14 @@ +#include +#include +#include -#include "bolt/CSTVisitor.hpp" #include "zen/graph.hpp" #include "bolt/ByteString.hpp" #include "bolt/CST.hpp" +#include "bolt/CSTVisitor.hpp" #include "bolt/Type.hpp" #include "bolt/Diagnostics.hpp" -#include -#include -#include -#include #include "bolt/Checker.hpp" namespace bolt { @@ -33,11 +32,30 @@ TypeScheme* TypeEnv::lookup(ByteString Name, SymbolKind Kind) { } void TypeEnv::add(ByteString Name, TypeScheme* Scm, SymbolKind Kind) { - Mapping.emplace(std::make_tuple(Name, Kind), Scm); + Mapping[std::make_tuple(Name, Kind)] = Scm; } void TypeEnv::add(ByteString Name, Type* Ty, SymbolKind Kind) { - add(Name, new TypeScheme { {}, Ty }, Kind); + add(Name, new TypeScheme { {}, {}, Ty }, Kind); +} + +void TypeEnv::dump() const { + for (auto [Tuple, Scm]: Mapping) { + auto Name = std::get<0>(Tuple); + auto Kind = std::get<1>(Tuple); + switch (Kind) { + case SymbolKind::Var: + std::cerr << "let " << Name << " : " << Scm->toString() << "\n"; + break; + case SymbolKind::Type: + std::cerr << "type " << Name << " = " << Scm->toString() << "\n"; + break; + case SymbolKind::Class: + ZEN_UNREACHABLE // TODO + case SymbolKind::Constructor: + ZEN_UNREACHABLE // TODO + } + } } using TVSub = std::unordered_map; @@ -94,6 +112,7 @@ Type* Checker::instantiate(TypeScheme* Scm) { auto Fresh = createTVar(); Sub[TV] = Fresh; } + // TODO instantiate constraints return substituteType(Scm->getType(), Sub); } @@ -467,26 +486,45 @@ bool TypeEnv::hasVar(TVar* TV) const { return false; } -auto getUnbound(const TypeEnv& Env, Type* Ty) { +static void addUnbound(Type* Ty, const TypeEnv& Env, std::unordered_set& Vars) { struct Visitor : public TypeVisitor { const TypeEnv& Env; - Visitor(const TypeEnv& Env): - Env(Env) {} - std::vector Out; + std::unordered_set& Vars; + Visitor(const TypeEnv& Env, std::unordered_set& Vars): + Env(Env), Vars(Vars) {} void visitVar(TVar* TV) { auto Solved = TV->find(); if (isa(Solved)) { auto Var = static_cast(Solved); if (!Env.hasVar(Var)) { - Out.push_back(Var); + Vars.emplace(Var); } } else { visit(Solved); } } - } V { Env }; + } V { Env, Vars }; V.visit(Ty); - return V.Out; +} + +static void addUnbound(const Constraint& C, const TypeEnv& Env, std::unordered_set& Vars) { + switch (C.getKind()) { + case ConstraintKind::TypesEqual: + { + auto TE = static_cast(C); + addUnbound(TE.getLeft(), Env, Vars); + addUnbound(TE.getRight(), Env, Vars); + break; + } + } +} + +static TypeScheme* generalize(const TypeEnv& Env, const ConstraintSet& Constraints, Type* Ty) { + std::unordered_set Vars; + for (const auto C: Constraints) { + addUnbound(*C, Env, Vars); + } + return new TypeScheme { Vars, Constraints, Ty }; } ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector& Elements, Type* RetTy) { @@ -564,10 +602,9 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector& Elements, Typ for (auto N: Mutual) { if (isa(N)) { auto Func = static_cast(N); - auto Unbound = getUnbound(Env, Func->getType()); Env.add( Func->getNameAsString(), - new TypeScheme { { Unbound.begin(), Unbound.end() }, Func->getType()->find() }, + generalize(Env, Out, Func->getType()->find()), SymbolKind::Var ); } @@ -737,7 +774,7 @@ void Checker::run(SourceFile* SF) { Env.add("not", new TFun(Bool, Bool), SymbolKind::Var); Env.add("+", new TFun(Int, new TFun(Int, Int)), SymbolKind::Var); Env.add("-", new TFun(Int, new TFun(Int, Int)), SymbolKind::Var); - Env.add("$", new TypeScheme({ A, B }, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var); + Env.add("$", new TypeScheme({ A, B }, {}, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var); auto Out = inferSourceFile(Env, SF); solve(Out); } diff --git a/src/Constraint.cc b/src/Constraint.cc new file mode 100644 index 000000000..c64143c60 --- /dev/null +++ b/src/Constraint.cc @@ -0,0 +1,17 @@ + +#include "bolt/Constraint.hpp" + +namespace bolt { + +std::string Constraint::toString() const { + switch (Kind) { + case ConstraintKind::TypesEqual: + return static_cast(this)->toString(); + } +} + +std::string CTypesEqual::toString() const { + return A->toString() + " ~ " + B->toString(); +} + +} diff --git a/src/Type.cc b/src/Type.cc index ab3ad7827..4ad18e99b 100644 --- a/src/Type.cc +++ b/src/Type.cc @@ -1,5 +1,6 @@ #include "zen/config.hpp" +#include #include "bolt/Type.hpp" @@ -74,7 +75,7 @@ std::string Type::toString() const { return F->getLeft()->toString() + " -> " + F->getRight()->toString(); } case TypeKind::Var: - return "α"; + return static_cast(this)->Name; } } @@ -99,4 +100,17 @@ void TypeVisitor::visit(Type* Ty) { } } +std::string TypeScheme::toString() const { + std::ostringstream Out; + if (!Unbound.empty()) { + Out << "forall"; + for (auto TV: Unbound) { + Out << " " << TV->toString(); + } + Out << ". "; + } + Out << Ty->toString(); + return Out.str(); +} + }