Make TypeEnv sort variables on whether they are a var or a function

Fixes RecordDeclaration and VariantDeclaration not working correctly
This commit is contained in:
Sam Vervaeck 2024-01-21 03:42:25 +01:00
parent 7b58a6c51f
commit 7ac3c39164
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
3 changed files with 101 additions and 69 deletions

View file

@ -1,21 +1,28 @@
#pragma once #pragma once
#include <cstdlib>
#include <unordered_map>
#include <vector>
#include <deque>
#include "zen/tuple_hash.hpp"
#include "bolt/ByteString.hpp" #include "bolt/ByteString.hpp"
#include "bolt/Common.hpp" #include "bolt/Common.hpp"
#include "bolt/CST.hpp" #include "bolt/CST.hpp"
#include "bolt/Type.hpp" #include "bolt/Type.hpp"
#include "bolt/Support/Graph.hpp" #include "bolt/Support/Graph.hpp"
#include <cstdlib>
#include <unordered_map>
#include <vector>
#include <deque>
namespace bolt { namespace bolt {
std::string describe(const Type* Ty); // For debugging only std::string describe(const Type* Ty); // For debugging only
enum class SymKind {
Type,
Var,
};
class DiagnosticEngine; class DiagnosticEngine;
class Constraint; class Constraint;
@ -70,7 +77,35 @@ namespace bolt {
}; };
using TypeEnv = std::unordered_map<ByteString, Scheme*>; class TypeEnv {
std::unordered_map<std::tuple<ByteString, SymKind>, Scheme*> Mapping;
public:
Scheme* lookup(ByteString Name, SymKind Kind) {
auto Key = std::make_tuple(Name, Kind);
auto Match = Mapping.find(Key);
if (Match == Mapping.end()) {
return nullptr;
}
return Match->second;
}
void add(ByteString Name, Scheme* Scm, SymKind Kind) {
auto Key = std::make_tuple(Name, Kind);
ZEN_ASSERT(!Mapping.count(Key))
// auto F = static_cast<Forall*>(Scm);
// std::cerr << Name << " : forall ";
// for (auto TV: *F->TVs) {
// std::cerr << describe(TV) << " ";
// }
// std::cerr << ". " << describe(F->Type) << "\n";
Mapping.emplace(Key, Scm);
}
};
enum class ConstraintKind { enum class ConstraintKind {
Equal, Equal,
@ -158,16 +193,6 @@ namespace bolt {
TypeEnv Env; TypeEnv Env;
void add(ByteString Name, Scheme* Scm) {
// auto F = static_cast<Forall*>(Scm);
// std::cerr << Name << " : forall ";
// for (auto TV: *F->TVs) {
// std::cerr << describe(TV) << " ";
// }
// std::cerr << ". " << describe(F->Type) << "\n";
Env.emplace(Name, Scm);
}
Type* ReturnType = nullptr; Type* ReturnType = nullptr;
InferContext* Parent = nullptr; InferContext* Parent = nullptr;
@ -240,7 +265,7 @@ namespace bolt {
/// Environment manipulation /// Environment manipulation
Scheme* lookup(ByteString Name); Scheme* lookup(ByteString Name, SymKind Kind);
/** /**
* Looks up a type/variable and ensures that it is a monomorphic type. * Looks up a type/variable and ensures that it is a monomorphic type.
@ -254,9 +279,9 @@ namespace bolt {
* \returns If the type/variable could not be found `nullptr` is returned. * \returns If the type/variable could not be found `nullptr` is returned.
* Otherwise, a [Type] is returned. * Otherwise, a [Type] is returned.
*/ */
Type* lookupMono(ByteString Name); Type* lookupMono(ByteString Name, SymKind Kind);
void addBinding(ByteString Name, Scheme* Scm); void addBinding(ByteString Name, Scheme* Scm, SymKind Kind);
/// Constraint solving /// Constraint solving

View file

@ -55,12 +55,12 @@ namespace bolt {
UnitType = new Type(TTuple({})); UnitType = new Type(TTuple({}));
} }
Scheme* Checker::lookup(ByteString Name) { Scheme* Checker::lookup(ByteString Name, SymKind Kind) {
auto Curr = &getContext(); auto Curr = &getContext();
for (;;) { for (;;) {
auto Match = Curr->Env.find(Name); auto Match = Curr->Env.lookup(Name, Kind);
if (Match != Curr->Env.end()) { if (Match != nullptr) {
return Match->second; return Match;
} }
Curr = Curr->Parent; Curr = Curr->Parent;
if (!Curr) { if (!Curr) {
@ -70,8 +70,8 @@ namespace bolt {
return nullptr; return nullptr;
} }
Type* Checker::lookupMono(ByteString Name) { Type* Checker::lookupMono(ByteString Name, SymKind Kind) {
auto Scm = lookup(Name); auto Scm = lookup(Name, Kind);
if (Scm == nullptr) { if (Scm == nullptr) {
return nullptr; return nullptr;
} }
@ -80,8 +80,8 @@ namespace bolt {
return F->Type; return F->Type;
} }
void Checker::addBinding(ByteString Name, Scheme* Scm) { void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) {
getContext().add(Name, Scm); getContext().Env.add(Name, Scm, Kind);
} }
Type* Checker::getReturnType() { Type* Checker::getReturnType() {
@ -296,29 +296,33 @@ namespace bolt {
Type* Ty = createConType(Decl->Name->getCanonicalText()); Type* Ty = createConType(Decl->Name->getCanonicalText());
// Build the type that is actually returned by constructor functions
auto RetTy = Ty;
for (auto Var: Vars) {
RetTy = new Type(TApp(RetTy, Var));
}
// Must be added early so we can create recursive types // Must be added early so we can create recursive types
Decl->Ctx->Parent->add(Decl->Name->getCanonicalText(), new Forall(Ty)); Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
for (auto Member: Decl->Members) { for (auto Member: Decl->Members) {
switch (Member->getKind()) { switch (Member->getKind()) {
case NodeKind::TupleVariantDeclarationMember: case NodeKind::TupleVariantDeclarationMember:
{ {
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member); auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
auto RetTy = Ty;
for (auto Var: Vars) {
RetTy = new Type(TApp(RetTy, Var));
}
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto Element: TupleMember->Elements) { for (auto Element: TupleMember->Elements) {
// inferTypeExpression will look up any TVars that were part of the signature of Decl
ParamTypes.push_back(inferTypeExpression(Element)); ParamTypes.push_back(inferTypeExpression(Element));
} }
Decl->Ctx->Parent->add( Decl->Ctx->Parent->Env.add(
TupleMember->Name->getCanonicalText(), TupleMember->Name->getCanonicalText(),
new Forall( new Forall(
Decl->Ctx->TVs, Decl->Ctx->TVs,
Decl->Ctx->Constraints, Decl->Ctx->Constraints,
Type::buildArrow(ParamTypes, RetTy) Type::buildArrow(ParamTypes, RetTy)
) ),
SymKind::Var
); );
break; break;
} }
@ -353,7 +357,12 @@ namespace bolt {
auto Ty = createConType(Name); auto Ty = createConType(Name);
// Must be added early so we can create recursive types // Must be added early so we can create recursive types
Decl->Ctx->Parent->add(Name, new Forall(Ty)); Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type);
Type* RetTy = Ty;
for (auto TV: Vars) {
RetTy = new Type(TApp(RetTy, TV));
}
// Corresponds to the logic of one branch of a VariantDeclarationMember // Corresponds to the logic of one branch of a VariantDeclarationMember
Type* FieldsTy = new Type(TNil()); Type* FieldsTy = new Type(TNil());
@ -366,18 +375,16 @@ namespace bolt {
) )
); );
} }
Type* RetTy = Ty; Decl->Ctx->Parent->Env.add(
for (auto TV: Vars) {
RetTy = new Type(TApp(RetTy, TV));
}
Decl->Ctx->Parent->add(
Name, Name,
new Forall( new Forall(
Decl->Ctx->TVs, Decl->Ctx->TVs,
Decl->Ctx->Constraints, Decl->Ctx->Constraints,
new Type(TArrow(FieldsTy, RetTy)) new Type(TArrow(FieldsTy, RetTy))
) ),
SymKind::Var
); );
popContext(); popContext();
break; break;
@ -463,7 +470,7 @@ namespace bolt {
auto Name = TE->Name->getCanonicalText(); auto Name = TE->Name->getCanonicalText();
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
TV->asVar().Context.emplace(Id); TV->asVar().Context.emplace(Id);
Ctx->add(Name, new Forall(TV)); Ctx->Env.add(Name, new Forall(TV), SymKind::Type);
Out.push_back(TV); Out.push_back(TV);
} }
return Out; return Out;
@ -553,7 +560,7 @@ namespace bolt {
} }
if (!Let->isInstance()) { if (!Let->isInstance()) {
Let->Ctx->Parent->add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty)); Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var);
} }
} }
@ -808,7 +815,7 @@ namespace bolt {
case NodeKind::ReferenceTypeExpression: case NodeKind::ReferenceTypeExpression:
{ {
auto RefTE = static_cast<ReferenceTypeExpression*>(N); auto RefTE = static_cast<ReferenceTypeExpression*>(N);
auto Scm = lookup(RefTE->Name->getCanonicalText()); auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type);
Type* Ty; Type* Ty;
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name); DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
@ -834,13 +841,13 @@ namespace bolt {
case NodeKind::VarTypeExpression: case NodeKind::VarTypeExpression:
{ {
auto VarTE = static_cast<VarTypeExpression*>(N); auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(VarTE->Name->getCanonicalText()); auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
if (Ty == nullptr) { if (Ty == nullptr) {
if (IsPoly && Config.typeVarsRequireForall()) { if (IsPoly && Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name); DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
} }
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
} }
ZEN_ASSERT(Ty->isVar()); ZEN_ASSERT(Ty->isVar());
N->setType(Ty); N->setType(Ty);
@ -974,7 +981,7 @@ namespace bolt {
auto Ref = static_cast<ReferenceExpression*>(X); auto Ref = static_cast<ReferenceExpression*>(X);
ZEN_ASSERT(Ref->ModulePath.empty()); ZEN_ASSERT(Ref->ModulePath.empty());
if (Ref->Name->is<IdentifierAlt>()) { if (Ref->Name->is<IdentifierAlt>()) {
auto Scm = lookup(Ref->Name->getCanonicalText()); auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var);
if (!Scm) { if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name); DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
Ty = createTypeVar(); Ty = createTypeVar();
@ -999,7 +1006,7 @@ namespace bolt {
infer(Let); infer(Let);
} }
} }
auto Scm = lookup(Ref->Name->getCanonicalText()); auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var);
ZEN_ASSERT(Scm); ZEN_ASSERT(Scm);
Ty = instantiate(Scm, X); Ty = instantiate(Scm, X);
break; break;
@ -1021,7 +1028,7 @@ namespace bolt {
case NodeKind::InfixExpression: case NodeKind::InfixExpression:
{ {
auto Infix = static_cast<InfixExpression*>(X); auto Infix = static_cast<InfixExpression*>(X);
auto Scm = lookup(Infix->Operator->getText()); auto Scm = lookup(Infix->Operator->getText(), SymKind::Var);
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator); DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
Ty = createTypeVar(); Ty = createTypeVar();
@ -1102,14 +1109,14 @@ namespace bolt {
{ {
auto P = static_cast<BindPattern*>(Pattern); auto P = static_cast<BindPattern*>(Pattern);
auto Ty = createTypeVar(); auto Ty = createTypeVar();
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty)); addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var);
return Ty; return Ty;
} }
case NodeKind::NamedPattern: case NodeKind::NamedPattern:
{ {
auto P = static_cast<NamedPattern*>(Pattern); auto P = static_cast<NamedPattern*>(Pattern);
auto Scm = lookup(P->Name->getCanonicalText()); auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto P2: P->Patterns) { for (auto P2: P->Patterns) {
ParamTypes.push_back(inferPattern(P2, Constraints, TVs)); ParamTypes.push_back(inferPattern(P2, Constraints, TVs));
@ -1167,10 +1174,10 @@ namespace bolt {
Type* Ty; Type* Ty;
switch (L->getKind()) { switch (L->getKind()) {
case NodeKind::IntegerLiteral: case NodeKind::IntegerLiteral:
Ty = lookupMono("Int"); Ty = lookupMono("Int", SymKind::Type);
break; break;
case NodeKind::StringLiteral: case NodeKind::StringLiteral:
Ty = lookupMono("String"); Ty = lookupMono("String", SymKind::Type);
break; break;
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
@ -1235,18 +1242,18 @@ namespace bolt {
void Checker::check(SourceFile *SF) { void Checker::check(SourceFile *SF) {
initialize(SF); initialize(SF);
setContext(SF->Ctx); setContext(SF->Ctx);
addBinding("String", new Forall(StringType)); addBinding("String", new Forall(StringType), SymKind::Type);
addBinding("Int", new Forall(IntType)); addBinding("Int", new Forall(IntType), SymKind::Type);
addBinding("Bool", new Forall(BoolType)); addBinding("Bool", new Forall(BoolType), SymKind::Type);
addBinding("List", new Forall(ListType)); addBinding("List", new Forall(ListType), SymKind::Type);
addBinding("True", new Forall(BoolType)); addBinding("True", new Forall(BoolType), SymKind::Var);
addBinding("False", new Forall(BoolType)); addBinding("False", new Forall(BoolType), SymKind::Var);
auto A = createTypeVar(); auto A = createTypeVar();
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType))); addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var);
addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType))); addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
populate(SF); populate(SF);
forwardDeclare(SF); forwardDeclare(SF);
auto SCCs = RefGraph.strongconnect(); auto SCCs = RefGraph.strongconnect();

View file

@ -1,10 +1,10 @@
enum List a. enum MyList a.
Nil Nil
Pair a (List a) Pair a (MyList a)
let x : List Int let x : MyList Int
@expect_diagnostic 2010 @expect_diagnostic 2010
let y : List Bool = x let y : MyList Bool = x
let z : List String let z : MyList String