Make TypeEnv sort variables on whether they are a var or a function
Fixes RecordDeclaration and VariantDeclaration not working correctly
This commit is contained in:
parent
7b58a6c51f
commit
7ac3c39164
3 changed files with 101 additions and 69 deletions
|
@ -1,21 +1,28 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
#include <deque>
|
||||||
|
|
||||||
|
#include "zen/tuple_hash.hpp"
|
||||||
|
|
||||||
#include "bolt/ByteString.hpp"
|
#include "bolt/ByteString.hpp"
|
||||||
#include "bolt/Common.hpp"
|
#include "bolt/Common.hpp"
|
||||||
#include "bolt/CST.hpp"
|
#include "bolt/CST.hpp"
|
||||||
#include "bolt/Type.hpp"
|
#include "bolt/Type.hpp"
|
||||||
#include "bolt/Support/Graph.hpp"
|
#include "bolt/Support/Graph.hpp"
|
||||||
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <vector>
|
|
||||||
#include <deque>
|
|
||||||
|
|
||||||
namespace bolt {
|
namespace bolt {
|
||||||
|
|
||||||
std::string describe(const Type* Ty); // For debugging only
|
std::string describe(const Type* Ty); // For debugging only
|
||||||
|
|
||||||
|
enum class SymKind {
|
||||||
|
Type,
|
||||||
|
Var,
|
||||||
|
};
|
||||||
|
|
||||||
class DiagnosticEngine;
|
class DiagnosticEngine;
|
||||||
|
|
||||||
class Constraint;
|
class Constraint;
|
||||||
|
@ -70,7 +77,35 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using TypeEnv = std::unordered_map<ByteString, Scheme*>;
|
class TypeEnv {
|
||||||
|
|
||||||
|
std::unordered_map<std::tuple<ByteString, SymKind>, Scheme*> Mapping;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
Scheme* lookup(ByteString Name, SymKind Kind) {
|
||||||
|
auto Key = std::make_tuple(Name, Kind);
|
||||||
|
auto Match = Mapping.find(Key);
|
||||||
|
if (Match == Mapping.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return Match->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(ByteString Name, Scheme* Scm, SymKind Kind) {
|
||||||
|
auto Key = std::make_tuple(Name, Kind);
|
||||||
|
ZEN_ASSERT(!Mapping.count(Key))
|
||||||
|
// auto F = static_cast<Forall*>(Scm);
|
||||||
|
// std::cerr << Name << " : forall ";
|
||||||
|
// for (auto TV: *F->TVs) {
|
||||||
|
// std::cerr << describe(TV) << " ";
|
||||||
|
// }
|
||||||
|
// std::cerr << ". " << describe(F->Type) << "\n";
|
||||||
|
Mapping.emplace(Key, Scm);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
enum class ConstraintKind {
|
enum class ConstraintKind {
|
||||||
Equal,
|
Equal,
|
||||||
|
@ -158,16 +193,6 @@ namespace bolt {
|
||||||
|
|
||||||
TypeEnv Env;
|
TypeEnv Env;
|
||||||
|
|
||||||
void add(ByteString Name, Scheme* Scm) {
|
|
||||||
// auto F = static_cast<Forall*>(Scm);
|
|
||||||
// std::cerr << Name << " : forall ";
|
|
||||||
// for (auto TV: *F->TVs) {
|
|
||||||
// std::cerr << describe(TV) << " ";
|
|
||||||
// }
|
|
||||||
// std::cerr << ". " << describe(F->Type) << "\n";
|
|
||||||
Env.emplace(Name, Scm);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type* ReturnType = nullptr;
|
Type* ReturnType = nullptr;
|
||||||
|
|
||||||
InferContext* Parent = nullptr;
|
InferContext* Parent = nullptr;
|
||||||
|
@ -240,7 +265,7 @@ namespace bolt {
|
||||||
|
|
||||||
/// Environment manipulation
|
/// Environment manipulation
|
||||||
|
|
||||||
Scheme* lookup(ByteString Name);
|
Scheme* lookup(ByteString Name, SymKind Kind);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Looks up a type/variable and ensures that it is a monomorphic type.
|
* Looks up a type/variable and ensures that it is a monomorphic type.
|
||||||
|
@ -254,9 +279,9 @@ namespace bolt {
|
||||||
* \returns If the type/variable could not be found `nullptr` is returned.
|
* \returns If the type/variable could not be found `nullptr` is returned.
|
||||||
* Otherwise, a [Type] is returned.
|
* Otherwise, a [Type] is returned.
|
||||||
*/
|
*/
|
||||||
Type* lookupMono(ByteString Name);
|
Type* lookupMono(ByteString Name, SymKind Kind);
|
||||||
|
|
||||||
void addBinding(ByteString Name, Scheme* Scm);
|
void addBinding(ByteString Name, Scheme* Scm, SymKind Kind);
|
||||||
|
|
||||||
/// Constraint solving
|
/// Constraint solving
|
||||||
|
|
||||||
|
|
|
@ -55,12 +55,12 @@ namespace bolt {
|
||||||
UnitType = new Type(TTuple({}));
|
UnitType = new Type(TTuple({}));
|
||||||
}
|
}
|
||||||
|
|
||||||
Scheme* Checker::lookup(ByteString Name) {
|
Scheme* Checker::lookup(ByteString Name, SymKind Kind) {
|
||||||
auto Curr = &getContext();
|
auto Curr = &getContext();
|
||||||
for (;;) {
|
for (;;) {
|
||||||
auto Match = Curr->Env.find(Name);
|
auto Match = Curr->Env.lookup(Name, Kind);
|
||||||
if (Match != Curr->Env.end()) {
|
if (Match != nullptr) {
|
||||||
return Match->second;
|
return Match;
|
||||||
}
|
}
|
||||||
Curr = Curr->Parent;
|
Curr = Curr->Parent;
|
||||||
if (!Curr) {
|
if (!Curr) {
|
||||||
|
@ -70,8 +70,8 @@ namespace bolt {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::lookupMono(ByteString Name) {
|
Type* Checker::lookupMono(ByteString Name, SymKind Kind) {
|
||||||
auto Scm = lookup(Name);
|
auto Scm = lookup(Name, Kind);
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -80,8 +80,8 @@ namespace bolt {
|
||||||
return F->Type;
|
return F->Type;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::addBinding(ByteString Name, Scheme* Scm) {
|
void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) {
|
||||||
getContext().add(Name, Scm);
|
getContext().Env.add(Name, Scm, Kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::getReturnType() {
|
Type* Checker::getReturnType() {
|
||||||
|
@ -296,29 +296,33 @@ namespace bolt {
|
||||||
|
|
||||||
Type* Ty = createConType(Decl->Name->getCanonicalText());
|
Type* Ty = createConType(Decl->Name->getCanonicalText());
|
||||||
|
|
||||||
|
// Build the type that is actually returned by constructor functions
|
||||||
|
auto RetTy = Ty;
|
||||||
|
for (auto Var: Vars) {
|
||||||
|
RetTy = new Type(TApp(RetTy, Var));
|
||||||
|
}
|
||||||
|
|
||||||
// Must be added early so we can create recursive types
|
// Must be added early so we can create recursive types
|
||||||
Decl->Ctx->Parent->add(Decl->Name->getCanonicalText(), new Forall(Ty));
|
Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
|
||||||
|
|
||||||
for (auto Member: Decl->Members) {
|
for (auto Member: Decl->Members) {
|
||||||
switch (Member->getKind()) {
|
switch (Member->getKind()) {
|
||||||
case NodeKind::TupleVariantDeclarationMember:
|
case NodeKind::TupleVariantDeclarationMember:
|
||||||
{
|
{
|
||||||
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
|
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
|
||||||
auto RetTy = Ty;
|
|
||||||
for (auto Var: Vars) {
|
|
||||||
RetTy = new Type(TApp(RetTy, Var));
|
|
||||||
}
|
|
||||||
std::vector<Type*> ParamTypes;
|
std::vector<Type*> ParamTypes;
|
||||||
for (auto Element: TupleMember->Elements) {
|
for (auto Element: TupleMember->Elements) {
|
||||||
|
// inferTypeExpression will look up any TVars that were part of the signature of Decl
|
||||||
ParamTypes.push_back(inferTypeExpression(Element));
|
ParamTypes.push_back(inferTypeExpression(Element));
|
||||||
}
|
}
|
||||||
Decl->Ctx->Parent->add(
|
Decl->Ctx->Parent->Env.add(
|
||||||
TupleMember->Name->getCanonicalText(),
|
TupleMember->Name->getCanonicalText(),
|
||||||
new Forall(
|
new Forall(
|
||||||
Decl->Ctx->TVs,
|
Decl->Ctx->TVs,
|
||||||
Decl->Ctx->Constraints,
|
Decl->Ctx->Constraints,
|
||||||
Type::buildArrow(ParamTypes, RetTy)
|
Type::buildArrow(ParamTypes, RetTy)
|
||||||
)
|
),
|
||||||
|
SymKind::Var
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -353,7 +357,12 @@ namespace bolt {
|
||||||
auto Ty = createConType(Name);
|
auto Ty = createConType(Name);
|
||||||
|
|
||||||
// Must be added early so we can create recursive types
|
// Must be added early so we can create recursive types
|
||||||
Decl->Ctx->Parent->add(Name, new Forall(Ty));
|
Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type);
|
||||||
|
|
||||||
|
Type* RetTy = Ty;
|
||||||
|
for (auto TV: Vars) {
|
||||||
|
RetTy = new Type(TApp(RetTy, TV));
|
||||||
|
}
|
||||||
|
|
||||||
// Corresponds to the logic of one branch of a VariantDeclarationMember
|
// Corresponds to the logic of one branch of a VariantDeclarationMember
|
||||||
Type* FieldsTy = new Type(TNil());
|
Type* FieldsTy = new Type(TNil());
|
||||||
|
@ -366,18 +375,16 @@ namespace bolt {
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Type* RetTy = Ty;
|
Decl->Ctx->Parent->Env.add(
|
||||||
for (auto TV: Vars) {
|
|
||||||
RetTy = new Type(TApp(RetTy, TV));
|
|
||||||
}
|
|
||||||
Decl->Ctx->Parent->add(
|
|
||||||
Name,
|
Name,
|
||||||
new Forall(
|
new Forall(
|
||||||
Decl->Ctx->TVs,
|
Decl->Ctx->TVs,
|
||||||
Decl->Ctx->Constraints,
|
Decl->Ctx->Constraints,
|
||||||
new Type(TArrow(FieldsTy, RetTy))
|
new Type(TArrow(FieldsTy, RetTy))
|
||||||
)
|
),
|
||||||
|
SymKind::Var
|
||||||
);
|
);
|
||||||
|
|
||||||
popContext();
|
popContext();
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
@ -463,7 +470,7 @@ namespace bolt {
|
||||||
auto Name = TE->Name->getCanonicalText();
|
auto Name = TE->Name->getCanonicalText();
|
||||||
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
|
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
|
||||||
TV->asVar().Context.emplace(Id);
|
TV->asVar().Context.emplace(Id);
|
||||||
Ctx->add(Name, new Forall(TV));
|
Ctx->Env.add(Name, new Forall(TV), SymKind::Type);
|
||||||
Out.push_back(TV);
|
Out.push_back(TV);
|
||||||
}
|
}
|
||||||
return Out;
|
return Out;
|
||||||
|
@ -553,7 +560,7 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!Let->isInstance()) {
|
if (!Let->isInstance()) {
|
||||||
Let->Ctx->Parent->add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty));
|
Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -808,7 +815,7 @@ namespace bolt {
|
||||||
case NodeKind::ReferenceTypeExpression:
|
case NodeKind::ReferenceTypeExpression:
|
||||||
{
|
{
|
||||||
auto RefTE = static_cast<ReferenceTypeExpression*>(N);
|
auto RefTE = static_cast<ReferenceTypeExpression*>(N);
|
||||||
auto Scm = lookup(RefTE->Name->getCanonicalText());
|
auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type);
|
||||||
Type* Ty;
|
Type* Ty;
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
|
DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
|
||||||
|
@ -834,13 +841,13 @@ namespace bolt {
|
||||||
case NodeKind::VarTypeExpression:
|
case NodeKind::VarTypeExpression:
|
||||||
{
|
{
|
||||||
auto VarTE = static_cast<VarTypeExpression*>(N);
|
auto VarTE = static_cast<VarTypeExpression*>(N);
|
||||||
auto Ty = lookupMono(VarTE->Name->getCanonicalText());
|
auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
|
||||||
if (Ty == nullptr) {
|
if (Ty == nullptr) {
|
||||||
if (IsPoly && Config.typeVarsRequireForall()) {
|
if (IsPoly && Config.typeVarsRequireForall()) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
|
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
|
||||||
}
|
}
|
||||||
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
|
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
|
||||||
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
|
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
|
||||||
}
|
}
|
||||||
ZEN_ASSERT(Ty->isVar());
|
ZEN_ASSERT(Ty->isVar());
|
||||||
N->setType(Ty);
|
N->setType(Ty);
|
||||||
|
@ -974,7 +981,7 @@ namespace bolt {
|
||||||
auto Ref = static_cast<ReferenceExpression*>(X);
|
auto Ref = static_cast<ReferenceExpression*>(X);
|
||||||
ZEN_ASSERT(Ref->ModulePath.empty());
|
ZEN_ASSERT(Ref->ModulePath.empty());
|
||||||
if (Ref->Name->is<IdentifierAlt>()) {
|
if (Ref->Name->is<IdentifierAlt>()) {
|
||||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var);
|
||||||
if (!Scm) {
|
if (!Scm) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
|
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
|
||||||
Ty = createTypeVar();
|
Ty = createTypeVar();
|
||||||
|
@ -999,7 +1006,7 @@ namespace bolt {
|
||||||
infer(Let);
|
infer(Let);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
auto Scm = lookup(Ref->Name->getCanonicalText(), SymKind::Var);
|
||||||
ZEN_ASSERT(Scm);
|
ZEN_ASSERT(Scm);
|
||||||
Ty = instantiate(Scm, X);
|
Ty = instantiate(Scm, X);
|
||||||
break;
|
break;
|
||||||
|
@ -1021,7 +1028,7 @@ namespace bolt {
|
||||||
case NodeKind::InfixExpression:
|
case NodeKind::InfixExpression:
|
||||||
{
|
{
|
||||||
auto Infix = static_cast<InfixExpression*>(X);
|
auto Infix = static_cast<InfixExpression*>(X);
|
||||||
auto Scm = lookup(Infix->Operator->getText());
|
auto Scm = lookup(Infix->Operator->getText(), SymKind::Var);
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
|
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
|
||||||
Ty = createTypeVar();
|
Ty = createTypeVar();
|
||||||
|
@ -1102,14 +1109,14 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto P = static_cast<BindPattern*>(Pattern);
|
auto P = static_cast<BindPattern*>(Pattern);
|
||||||
auto Ty = createTypeVar();
|
auto Ty = createTypeVar();
|
||||||
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty));
|
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var);
|
||||||
return Ty;
|
return Ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
case NodeKind::NamedPattern:
|
case NodeKind::NamedPattern:
|
||||||
{
|
{
|
||||||
auto P = static_cast<NamedPattern*>(Pattern);
|
auto P = static_cast<NamedPattern*>(Pattern);
|
||||||
auto Scm = lookup(P->Name->getCanonicalText());
|
auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
|
||||||
std::vector<Type*> ParamTypes;
|
std::vector<Type*> ParamTypes;
|
||||||
for (auto P2: P->Patterns) {
|
for (auto P2: P->Patterns) {
|
||||||
ParamTypes.push_back(inferPattern(P2, Constraints, TVs));
|
ParamTypes.push_back(inferPattern(P2, Constraints, TVs));
|
||||||
|
@ -1167,10 +1174,10 @@ namespace bolt {
|
||||||
Type* Ty;
|
Type* Ty;
|
||||||
switch (L->getKind()) {
|
switch (L->getKind()) {
|
||||||
case NodeKind::IntegerLiteral:
|
case NodeKind::IntegerLiteral:
|
||||||
Ty = lookupMono("Int");
|
Ty = lookupMono("Int", SymKind::Type);
|
||||||
break;
|
break;
|
||||||
case NodeKind::StringLiteral:
|
case NodeKind::StringLiteral:
|
||||||
Ty = lookupMono("String");
|
Ty = lookupMono("String", SymKind::Type);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
ZEN_UNREACHABLE
|
ZEN_UNREACHABLE
|
||||||
|
@ -1235,18 +1242,18 @@ namespace bolt {
|
||||||
void Checker::check(SourceFile *SF) {
|
void Checker::check(SourceFile *SF) {
|
||||||
initialize(SF);
|
initialize(SF);
|
||||||
setContext(SF->Ctx);
|
setContext(SF->Ctx);
|
||||||
addBinding("String", new Forall(StringType));
|
addBinding("String", new Forall(StringType), SymKind::Type);
|
||||||
addBinding("Int", new Forall(IntType));
|
addBinding("Int", new Forall(IntType), SymKind::Type);
|
||||||
addBinding("Bool", new Forall(BoolType));
|
addBinding("Bool", new Forall(BoolType), SymKind::Type);
|
||||||
addBinding("List", new Forall(ListType));
|
addBinding("List", new Forall(ListType), SymKind::Type);
|
||||||
addBinding("True", new Forall(BoolType));
|
addBinding("True", new Forall(BoolType), SymKind::Var);
|
||||||
addBinding("False", new Forall(BoolType));
|
addBinding("False", new Forall(BoolType), SymKind::Var);
|
||||||
auto A = createTypeVar();
|
auto A = createTypeVar();
|
||||||
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)));
|
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var);
|
||||||
addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||||
addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||||
addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||||
addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
|
addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var);
|
||||||
populate(SF);
|
populate(SF);
|
||||||
forwardDeclare(SF);
|
forwardDeclare(SF);
|
||||||
auto SCCs = RefGraph.strongconnect();
|
auto SCCs = RefGraph.strongconnect();
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
|
|
||||||
enum List a.
|
enum MyList a.
|
||||||
Nil
|
Nil
|
||||||
Pair a (List a)
|
Pair a (MyList a)
|
||||||
|
|
||||||
let x : List Int
|
let x : MyList Int
|
||||||
@expect_diagnostic 2010
|
@expect_diagnostic 2010
|
||||||
let y : List Bool = x
|
let y : MyList Bool = x
|
||||||
let z : List String
|
let z : MyList String
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue