Make TypeEnv sort variables on whether they are a var or a function
Fixes RecordDeclaration and VariantDeclaration not working correctly
This commit is contained in:
parent
7b58a6c51f
commit
7ac3c39164
3 changed files with 101 additions and 69 deletions
|
@ -1,21 +1,28 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
|
||||
#include "zen/tuple_hash.hpp"
|
||||
|
||||
#include "bolt/ByteString.hpp"
|
||||
#include "bolt/Common.hpp"
|
||||
#include "bolt/CST.hpp"
|
||||
#include "bolt/Type.hpp"
|
||||
#include "bolt/Support/Graph.hpp"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
|
||||
namespace bolt {
|
||||
|
||||
std::string describe(const Type* Ty); // For debugging only
|
||||
|
||||
enum class SymKind {
|
||||
Type,
|
||||
Var,
|
||||
};
|
||||
|
||||
class DiagnosticEngine;
|
||||
|
||||
class Constraint;
|
||||
|
@ -70,7 +77,35 @@ namespace bolt {
|
|||
|
||||
};
|
||||
|
||||
using TypeEnv = std::unordered_map<ByteString, Scheme*>;
|
||||
class TypeEnv {
|
||||
|
||||
std::unordered_map<std::tuple<ByteString, SymKind>, Scheme*> Mapping;
|
||||
|
||||
public:
|
||||
|
||||
Scheme* lookup(ByteString Name, SymKind Kind) {
|
||||
auto Key = std::make_tuple(Name, Kind);
|
||||
auto Match = Mapping.find(Key);
|
||||
if (Match == Mapping.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return Match->second;
|
||||
}
|
||||
|
||||
void add(ByteString Name, Scheme* Scm, SymKind Kind) {
|
||||
auto Key = std::make_tuple(Name, Kind);
|
||||
ZEN_ASSERT(!Mapping.count(Key))
|
||||
// auto F = static_cast<Forall*>(Scm);
|
||||
// std::cerr << Name << " : forall ";
|
||||
// for (auto TV: *F->TVs) {
|
||||
// std::cerr << describe(TV) << " ";
|
||||
// }
|
||||
// std::cerr << ". " << describe(F->Type) << "\n";
|
||||
Mapping.emplace(Key, Scm);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
enum class ConstraintKind {
|
||||
Equal,
|
||||
|
@ -158,16 +193,6 @@ namespace bolt {
|
|||
|
||||
TypeEnv Env;
|
||||
|
||||
void add(ByteString Name, Scheme* Scm) {
|
||||
// auto F = static_cast<Forall*>(Scm);
|
||||
// std::cerr << Name << " : forall ";
|
||||
// for (auto TV: *F->TVs) {
|
||||
// std::cerr << describe(TV) << " ";
|
||||
// }
|
||||
// std::cerr << ". " << describe(F->Type) << "\n";
|
||||
Env.emplace(Name, Scm);
|
||||
}
|
||||
|
||||
Type* ReturnType = nullptr;
|
||||
|
||||
InferContext* Parent = nullptr;
|
||||
|
@ -240,7 +265,7 @@ namespace bolt {
|
|||
|
||||
/// Environment manipulation
|
||||
|
||||
Scheme* lookup(ByteString Name);
|
||||
Scheme* lookup(ByteString Name, SymKind Kind);
|
||||
|
||||
/**
|
||||
* Looks up a type/variable and ensures that it is a monomorphic type.
|
||||
|
@ -254,9 +279,9 @@ namespace bolt {
|
|||
* \returns If the type/variable could not be found `nullptr` is returned.
|
||||
* Otherwise, a [Type] is returned.
|
||||
*/
|
||||
Type* lookupMono(ByteString Name);
|
||||
Type* lookupMono(ByteString Name, SymKind Kind);
|
||||
|
||||
void addBinding(ByteString Name, Scheme* Scm);
|
||||
void addBinding(ByteString Name, Scheme* Scm, SymKind Kind);
|
||||
|
||||
/// Constraint solving
|
||||
|
||||
|
|
|
@ -55,12 +55,12 @@ namespace bolt {
|
|||
UnitType = new Type(TTuple({}));
|
||||
}
|
||||
|
||||
Scheme* Checker::lookup(ByteString Name) {
|
||||
Scheme* Checker::lookup(ByteString Name, SymKind Kind) {
|
||||
auto Curr = &getContext();
|
||||
for (;;) {
|
||||
auto Match = Curr->Env.find(Name);
|
||||
if (Match != Curr->Env.end()) {
|
||||
return Match->second;
|
||||
auto Match = Curr->Env.lookup(Name, Kind);
|
||||
if (Match != nullptr) {
|
||||
return Match;
|
||||
}
|
||||
Curr = Curr->Parent;
|
||||
if (!Curr) {
|
||||
|
@ -70,8 +70,8 @@ namespace bolt {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
Type* Checker::lookupMono(ByteString Name) {
|
||||
auto Scm = lookup(Name);
|
||||
Type* Checker::lookupMono(ByteString Name, SymKind Kind) {
|
||||
auto Scm = lookup(Name, Kind);
|
||||
if (Scm == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -80,8 +80,8 @@ namespace bolt {
|
|||
return F->Type;
|
||||
}
|
||||
|
||||
void Checker::addBinding(ByteString Name, Scheme* Scm) {
|
||||
getContext().add(Name, Scm);
|
||||
void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) {
|
||||
getContext().Env.add(Name, Scm, Kind);
|
||||
}
|
||||
|
||||
Type* Checker::getReturnType() {
|
||||
|
@ -296,29 +296,33 @@ namespace bolt {
|
|||
|
||||
Type* Ty = createConType(Decl->Name->getCanonicalText());
|
||||
|
||||
// Build the type that is actually returned by constructor functions
|
||||
auto RetTy = Ty;
|
||||
for (auto Var: Vars) {
|
||||
RetTy = new Type(TApp(RetTy, Var));
|
||||
}
|
||||
|
||||
// Must be added early so we can create recursive types
|
||||
Decl->Ctx->Parent->add(Decl->Name->getCanonicalText(), new Forall(Ty));
|
||||
Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
|
||||
|
||||
for (auto Member: Decl->Members) {
|
||||
switch (Member->getKind()) {
|
||||
case NodeKind::TupleVariantDeclarationMember:
|
||||
{
|
||||
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
|
||||
auto RetTy = Ty;
|
||||
for (auto Var: Vars) {
|
||||
RetTy = new Type(TApp(RetTy, Var));
|
||||
}
|
||||
std::vector<Type*> ParamTypes;
|
||||
for (auto Element: TupleMember->Elements) {
|
||||
// inferTypeExpression will look up any TVars that were part of the signature of Decl
|
||||
ParamTypes.push_back(inferTypeExpression(Element));
|
||||
}
|
||||
Decl->Ctx->Parent->add(
|
||||
Decl->Ctx->Parent->Env.add(
|
||||
TupleMember->Name->getCanonicalText(),
|
||||
new Forall(
|
||||
Decl->Ctx->TVs,
|
||||
Decl->Ctx->Constraints,
|
||||
Type::buildArrow(ParamTypes, RetTy)
|
||||
)
|
||||
),
|
||||
SymKind::Var
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
@ -353,7 +357,12 @@ namespace bolt {
|
|||
auto Ty = createConType(Name);
|
||||
|
||||
// Must be added early so we can create recursive types
|
||||
Decl->Ctx->Parent->add(Name, new Forall(Ty));
|
||||
Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type);
|
||||
|
||||
Type* RetTy = Ty;
|
||||
for (auto TV: Vars) {
|
||||
RetTy = new Type(TApp(RetTy, TV));
|
||||
}
|
||||
|
||||
// Corresponds to the logic of one branch of a VariantDeclarationMember
|
||||
Type* FieldsTy = new Type(TNil());
|
||||
|
@ -366,18 +375,16 @@ namespace bolt {
|
|||
)
|
||||
);
|
||||
}
|
||||
Type* RetTy = Ty;
|
||||
for (auto TV: Vars) {
|
||||
RetTy = new Type(TApp(RetTy, TV));
|
||||
}
|
||||
Decl->Ctx->Parent->add(
|
||||
Decl->Ctx->Parent->Env.add(
|
||||
Name,
|
||||
new Forall(
|
||||
Decl->Ctx->TVs,
|
||||
Decl->Ctx->Constraints,
|
||||
new Type(TArrow(FieldsTy, RetTy))
|
||||
)
|
||||
),
|
||||
SymKind::Var
|
||||
);
|
||||
|
||||
popContext();
|
||||
|
||||
break;
|
||||
|
@ -463,7 +470,7 @@ namespace bolt {
|
|||
auto Name = TE->Name->getCanonicalText();
|
||||
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
|
||||
TV->asVar().Context.emplace(Id);
|
||||
Ctx->add(Name, new Forall(TV));
|
||||
Ctx->Env.add(Name, new Forall(TV), SymKind::Type);
|
||||
Out.push_back(TV);
|
||||
}
|
||||
return Out;
|
||||
|
@ -553,7 +560,7 @@ namespace bolt {
|
|||
}
|
||||
|
||||
if (!Let->isInstance()) {
|
||||
Let->Ctx->Parent->add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty));
|
||||
Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -808,7 +815,7 @@ namespace bolt {
|
|||
case NodeKind::ReferenceTypeExpression:
|
||||
{
|
||||
auto RefTE = static_cast<ReferenceTypeExpression*>(N);
|
||||
auto Scm = lookup(RefTE->Name->getCanonicalText());
|
||||
auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type);
|
||||
Type* Ty;
|
||||
if (Scm == nullptr) {
|
||||
DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
|
||||
|
@ -834,13 +841,13 @@ namespace bolt {
|
|||
case NodeKind::VarTypeExpression:
|
||||
{
|
||||
auto VarTE = static_cast<VarTypeExpression*>(N);
|
||||
auto Ty = lookupMono(VarTE->Name->getCanonicalText());
|
||||
auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
|
||||
if (Ty == nullptr) {
|
||||
if (IsPoly && Config.typeVarsRequireForall()) {
|
||||
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
|
||||
}
|
||||
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
|
||||
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
|
||||
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
|
||||
}
|
||||
ZEN_ASSERT(Ty->isVar());
|
||||
N->setType(Ty);
|
||||
|
@ -974,7 +981,7 @@ namespace bolt {
|
|||
auto Ref = static_cast<ReferenceExpression*>(X);
|
||||
ZEN_ASSERT(Ref->ModulePath.empty());
|
||||
if (Ref->Name->is<IdentifierAlt>()) {
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var);
|
||||
if (!Scm) {
|
||||
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
|
||||
Ty = createTypeVar();
|
||||
|
@ -999,7 +1006,7 @@ namespace bolt {
|
|||
infer(Let);
|
||||
}
|
||||
}
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var);
|
||||
ZEN_ASSERT(Scm);
|
||||
Ty = instantiate(Scm, X);
|
||||
break;
|
||||
|
@ -1021,7 +1028,7 @@ namespace bolt {
|
|||
case NodeKind::InfixExpression:
|
||||
{
|
||||
auto Infix = static_cast<InfixExpression*>(X);
|
||||
auto Scm = lookup(Infix->Operator->getText());
|
||||
auto Scm = lookup(Infix->Operator->getText(), SymKind::Var);
|
||||
if (Scm == nullptr) {
|
||||
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
|
||||
Ty = createTypeVar();
|
||||
|
@ -1102,14 +1109,14 @@ namespace bolt {
|
|||
{
|
||||
auto P = static_cast<BindPattern*>(Pattern);
|
||||
auto Ty = createTypeVar();
|
||||
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty));
|
||||
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var);
|
||||
return Ty;
|
||||
}
|
||||
|
||||
case NodeKind::NamedPattern:
|
||||
{
|
||||
auto P = static_cast<NamedPattern*>(Pattern);
|
||||
auto Scm = lookup(P->Name->getCanonicalText());
|
||||
auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
|
||||
std::vector<Type*> ParamTypes;
|
||||
for (auto P2: P->Patterns) {
|
||||
ParamTypes.push_back(inferPattern(P2, Constraints, TVs));
|
||||
|
@ -1167,10 +1174,10 @@ namespace bolt {
|
|||
Type* Ty;
|
||||
switch (L->getKind()) {
|
||||
case NodeKind::IntegerLiteral:
|
||||
Ty = lookupMono("Int");
|
||||
Ty = lookupMono("Int", SymKind::Type);
|
||||
break;
|
||||
case NodeKind::StringLiteral:
|
||||
Ty = lookupMono("String");
|
||||
Ty = lookupMono("String", SymKind::Type);
|
||||
break;
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
|
@ -1235,18 +1242,18 @@ namespace bolt {
|
|||
void Checker::check(SourceFile *SF) {
|
||||
initialize(SF);
|
||||
setContext(SF->Ctx);
|
||||
addBinding("String", new Forall(StringType));
|
||||
addBinding("Int", new Forall(IntType));
|
||||
addBinding("Bool", new Forall(BoolType));
|
||||
addBinding("List", new Forall(ListType));
|
||||
addBinding("True", new Forall(BoolType));
|
||||
addBinding("False", new Forall(BoolType));
|
||||
addBinding("String", new Forall(StringType), SymKind::Type);
|
||||
addBinding("Int", new Forall(IntType), SymKind::Type);
|
||||
addBinding("Bool", new Forall(BoolType), SymKind::Type);
|
||||
addBinding("List", new Forall(ListType), SymKind::Type);
|
||||
addBinding("True", new Forall(BoolType), SymKind::Var);
|
||||
addBinding("False", new Forall(BoolType), SymKind::Var);
|
||||
auto A = createTypeVar();
|
||||
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)));
|
||||
addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var);
|
||||
addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||
addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||
addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||
addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||
populate(SF);
|
||||
forwardDeclare(SF);
|
||||
auto SCCs = RefGraph.strongconnect();
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
|
||||
enum List a.
|
||||
enum MyList a.
|
||||
Nil
|
||||
Pair a (List a)
|
||||
Pair a (MyList a)
|
||||
|
||||
let x : List Int
|
||||
let x : MyList Int
|
||||
@expect_diagnostic 2010
|
||||
let y : List Bool = x
|
||||
let z : List String
|
||||
let y : MyList Bool = x
|
||||
let z : MyList String
|
||||
|
||||
|
|
Loading…
Reference in a new issue