From 7ac3c39164bcf2a368cdaba5e0997b75e635a57c Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sun, 21 Jan 2024 03:42:25 +0100 Subject: [PATCH] Make TypeEnv sort variables on whether they are a var or a function Fixes RecordDeclaration and VariantDeclaration not working correctly --- bootstrap/cxx/include/bolt/Checker.hpp | 63 ++++++++---- bootstrap/cxx/src/Checker.cc | 97 ++++++++++--------- .../cxx/test/checker/test_equality_tapp.bolt | 10 +- 3 files changed, 101 insertions(+), 69 deletions(-) diff --git a/bootstrap/cxx/include/bolt/Checker.hpp b/bootstrap/cxx/include/bolt/Checker.hpp index 729d546b9..4d8f96e1b 100644 --- a/bootstrap/cxx/include/bolt/Checker.hpp +++ b/bootstrap/cxx/include/bolt/Checker.hpp @@ -1,21 +1,28 @@ #pragma once +#include +#include +#include +#include + +#include "zen/tuple_hash.hpp" + #include "bolt/ByteString.hpp" #include "bolt/Common.hpp" #include "bolt/CST.hpp" #include "bolt/Type.hpp" #include "bolt/Support/Graph.hpp" -#include -#include -#include -#include - namespace bolt { std::string describe(const Type* Ty); // For debugging only + enum class SymKind { + Type, + Var, + }; + class DiagnosticEngine; class Constraint; @@ -70,7 +77,35 @@ namespace bolt { }; - using TypeEnv = std::unordered_map; + class TypeEnv { + + std::unordered_map, Scheme*> Mapping; + + public: + + Scheme* lookup(ByteString Name, SymKind Kind) { + auto Key = std::make_tuple(Name, Kind); + auto Match = Mapping.find(Key); + if (Match == Mapping.end()) { + return nullptr; + } + return Match->second; + } + + void add(ByteString Name, Scheme* Scm, SymKind Kind) { + auto Key = std::make_tuple(Name, Kind); + ZEN_ASSERT(!Mapping.count(Key)) + // auto F = static_cast(Scm); + // std::cerr << Name << " : forall "; + // for (auto TV: *F->TVs) { + // std::cerr << describe(TV) << " "; + // } + // std::cerr << ". " << describe(F->Type) << "\n"; + Mapping.emplace(Key, Scm); + } + + }; + enum class ConstraintKind { Equal, @@ -158,16 +193,6 @@ namespace bolt { TypeEnv Env; - void add(ByteString Name, Scheme* Scm) { - // auto F = static_cast(Scm); - // std::cerr << Name << " : forall "; - // for (auto TV: *F->TVs) { - // std::cerr << describe(TV) << " "; - // } - // std::cerr << ". " << describe(F->Type) << "\n"; - Env.emplace(Name, Scm); - } - Type* ReturnType = nullptr; InferContext* Parent = nullptr; @@ -240,7 +265,7 @@ namespace bolt { /// Environment manipulation - Scheme* lookup(ByteString Name); + Scheme* lookup(ByteString Name, SymKind Kind); /** * Looks up a type/variable and ensures that it is a monomorphic type. @@ -254,9 +279,9 @@ namespace bolt { * \returns If the type/variable could not be found `nullptr` is returned. * Otherwise, a [Type] is returned. */ - Type* lookupMono(ByteString Name); + Type* lookupMono(ByteString Name, SymKind Kind); - void addBinding(ByteString Name, Scheme* Scm); + void addBinding(ByteString Name, Scheme* Scm, SymKind Kind); /// Constraint solving diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index 2e0578405..029299884 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -55,12 +55,12 @@ namespace bolt { UnitType = new Type(TTuple({})); } - Scheme* Checker::lookup(ByteString Name) { + Scheme* Checker::lookup(ByteString Name, SymKind Kind) { auto Curr = &getContext(); for (;;) { - auto Match = Curr->Env.find(Name); - if (Match != Curr->Env.end()) { - return Match->second; + auto Match = Curr->Env.lookup(Name, Kind); + if (Match != nullptr) { + return Match; } Curr = Curr->Parent; if (!Curr) { @@ -70,8 +70,8 @@ namespace bolt { return nullptr; } - Type* Checker::lookupMono(ByteString Name) { - auto Scm = lookup(Name); + Type* Checker::lookupMono(ByteString Name, SymKind Kind) { + auto Scm = lookup(Name, Kind); if (Scm == nullptr) { return nullptr; } @@ -80,8 +80,8 @@ namespace bolt { return F->Type; } - void Checker::addBinding(ByteString Name, Scheme* Scm) { - getContext().add(Name, Scm); + void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) { + getContext().Env.add(Name, Scm, Kind); } Type* Checker::getReturnType() { @@ -296,29 +296,33 @@ namespace bolt { Type* Ty = createConType(Decl->Name->getCanonicalText()); + // Build the type that is actually returned by constructor functions + auto RetTy = Ty; + for (auto Var: Vars) { + RetTy = new Type(TApp(RetTy, Var)); + } + // Must be added early so we can create recursive types - Decl->Ctx->Parent->add(Decl->Name->getCanonicalText(), new Forall(Ty)); + Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); for (auto Member: Decl->Members) { switch (Member->getKind()) { case NodeKind::TupleVariantDeclarationMember: { auto TupleMember = static_cast(Member); - auto RetTy = Ty; - for (auto Var: Vars) { - RetTy = new Type(TApp(RetTy, Var)); - } std::vector ParamTypes; for (auto Element: TupleMember->Elements) { + // inferTypeExpression will look up any TVars that were part of the signature of Decl ParamTypes.push_back(inferTypeExpression(Element)); } - Decl->Ctx->Parent->add( + Decl->Ctx->Parent->Env.add( TupleMember->Name->getCanonicalText(), new Forall( Decl->Ctx->TVs, Decl->Ctx->Constraints, Type::buildArrow(ParamTypes, RetTy) - ) + ), + SymKind::Var ); break; } @@ -353,7 +357,12 @@ namespace bolt { auto Ty = createConType(Name); // Must be added early so we can create recursive types - Decl->Ctx->Parent->add(Name, new Forall(Ty)); + Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type); + + Type* RetTy = Ty; + for (auto TV: Vars) { + RetTy = new Type(TApp(RetTy, TV)); + } // Corresponds to the logic of one branch of a VariantDeclarationMember Type* FieldsTy = new Type(TNil()); @@ -366,18 +375,16 @@ namespace bolt { ) ); } - Type* RetTy = Ty; - for (auto TV: Vars) { - RetTy = new Type(TApp(RetTy, TV)); - } - Decl->Ctx->Parent->add( + Decl->Ctx->Parent->Env.add( Name, new Forall( Decl->Ctx->TVs, Decl->Ctx->Constraints, new Type(TArrow(FieldsTy, RetTy)) - ) + ), + SymKind::Var ); + popContext(); break; @@ -463,7 +470,7 @@ namespace bolt { auto Name = TE->Name->getCanonicalText(); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); TV->asVar().Context.emplace(Id); - Ctx->add(Name, new Forall(TV)); + Ctx->Env.add(Name, new Forall(TV), SymKind::Type); Out.push_back(TV); } return Out; @@ -553,7 +560,7 @@ namespace bolt { } if (!Let->isInstance()) { - Let->Ctx->Parent->add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty)); + Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var); } } @@ -808,7 +815,7 @@ namespace bolt { case NodeKind::ReferenceTypeExpression: { auto RefTE = static_cast(N); - auto Scm = lookup(RefTE->Name->getCanonicalText()); + auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type); Type* Ty; if (Scm == nullptr) { DE.add(RefTE->Name->getCanonicalText(), RefTE->Name); @@ -834,13 +841,13 @@ namespace bolt { case NodeKind::VarTypeExpression: { auto VarTE = static_cast(N); - auto Ty = lookupMono(VarTE->Name->getCanonicalText()); + auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type); if (Ty == nullptr) { if (IsPoly && Config.typeVarsRequireForall()) { DE.add(VarTE->Name->getCanonicalText(), VarTE->Name); } Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); - addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); + addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); } ZEN_ASSERT(Ty->isVar()); N->setType(Ty); @@ -974,7 +981,7 @@ namespace bolt { auto Ref = static_cast(X); ZEN_ASSERT(Ref->ModulePath.empty()); if (Ref->Name->is()) { - auto Scm = lookup(Ref->Name->getCanonicalText()); + auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var); if (!Scm) { DE.add(Ref->Name->getCanonicalText(), Ref->Name); Ty = createTypeVar(); @@ -999,7 +1006,7 @@ namespace bolt { infer(Let); } } - auto Scm = lookup(Ref->Name->getCanonicalText()); + auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var); ZEN_ASSERT(Scm); Ty = instantiate(Scm, X); break; @@ -1021,7 +1028,7 @@ namespace bolt { case NodeKind::InfixExpression: { auto Infix = static_cast(X); - auto Scm = lookup(Infix->Operator->getText()); + auto Scm = lookup(Infix->Operator->getText(), SymKind::Var); if (Scm == nullptr) { DE.add(Infix->Operator->getText(), Infix->Operator); Ty = createTypeVar(); @@ -1102,14 +1109,14 @@ namespace bolt { { auto P = static_cast(Pattern); auto Ty = createTypeVar(); - addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty)); + addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var); return Ty; } case NodeKind::NamedPattern: { auto P = static_cast(Pattern); - auto Scm = lookup(P->Name->getCanonicalText()); + auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); std::vector ParamTypes; for (auto P2: P->Patterns) { ParamTypes.push_back(inferPattern(P2, Constraints, TVs)); @@ -1167,10 +1174,10 @@ namespace bolt { Type* Ty; switch (L->getKind()) { case NodeKind::IntegerLiteral: - Ty = lookupMono("Int"); + Ty = lookupMono("Int", SymKind::Type); break; case NodeKind::StringLiteral: - Ty = lookupMono("String"); + Ty = lookupMono("String", SymKind::Type); break; default: ZEN_UNREACHABLE @@ -1235,18 +1242,18 @@ namespace bolt { void Checker::check(SourceFile *SF) { initialize(SF); setContext(SF->Ctx); - addBinding("String", new Forall(StringType)); - addBinding("Int", new Forall(IntType)); - addBinding("Bool", new Forall(BoolType)); - addBinding("List", new Forall(ListType)); - addBinding("True", new Forall(BoolType)); - addBinding("False", new Forall(BoolType)); + addBinding("String", new Forall(StringType), SymKind::Type); + addBinding("Int", new Forall(IntType), SymKind::Type); + addBinding("Bool", new Forall(BoolType), SymKind::Type); + addBinding("List", new Forall(ListType), SymKind::Type); + addBinding("True", new Forall(BoolType), SymKind::Var); + addBinding("False", new Forall(BoolType), SymKind::Var); auto A = createTypeVar(); - addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType))); - addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); - addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); - addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); - addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); + addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var); + addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); populate(SF); forwardDeclare(SF); auto SCCs = RefGraph.strongconnect(); diff --git a/bootstrap/cxx/test/checker/test_equality_tapp.bolt b/bootstrap/cxx/test/checker/test_equality_tapp.bolt index d7611b84c..cc296447b 100644 --- a/bootstrap/cxx/test/checker/test_equality_tapp.bolt +++ b/bootstrap/cxx/test/checker/test_equality_tapp.bolt @@ -1,10 +1,10 @@ -enum List a. +enum MyList a. Nil - Pair a (List a) + Pair a (MyList a) -let x : List Int +let x : MyList Int @expect_diagnostic 2010 -let y : List Bool = x -let z : List String +let y : MyList Bool = x +let z : MyList String