Make InferContext have a parent context
This commit is contained in:
parent
83e89f4e8c
commit
a8f8658f27
3 changed files with 180 additions and 109 deletions
|
@ -17,6 +17,7 @@
|
||||||
namespace bolt {
|
namespace bolt {
|
||||||
|
|
||||||
class Type;
|
class Type;
|
||||||
|
class InferContext;
|
||||||
|
|
||||||
class Token;
|
class Token;
|
||||||
class SourceFile;
|
class SourceFile;
|
||||||
|
@ -1279,6 +1280,8 @@ namespace bolt {
|
||||||
class MatchCase : public Node {
|
class MatchCase : public Node {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
InferContext* Ctx;
|
||||||
|
|
||||||
class Pattern* Pattern;
|
class Pattern* Pattern;
|
||||||
class RArrowAlt* RArrowAlt;
|
class RArrowAlt* RArrowAlt;
|
||||||
class Expression* Expression;
|
class Expression* Expression;
|
||||||
|
@ -1658,9 +1661,6 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class Type;
|
|
||||||
class InferContext;
|
|
||||||
|
|
||||||
class LetDeclaration : public Node {
|
class LetDeclaration : public Node {
|
||||||
|
|
||||||
Scope* TheScope = nullptr;
|
Scope* TheScope = nullptr;
|
||||||
|
@ -1703,6 +1703,22 @@ namespace bolt {
|
||||||
return TheScope;
|
return TheScope;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isFunc() const noexcept {
|
||||||
|
return !Params.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isVar() const noexcept {
|
||||||
|
return !isFunc();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isInstance() const noexcept {
|
||||||
|
return Parent->getKind() == NodeKind::InstanceDeclaration;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isClass() const noexcept {
|
||||||
|
return Parent->getKind() == NodeKind::ClassDeclaration;
|
||||||
|
}
|
||||||
|
|
||||||
Token* getFirstToken() const override;
|
Token* getFirstToken() const override;
|
||||||
Token* getLastToken() const override;
|
Token* getLastToken() const override;
|
||||||
|
|
||||||
|
@ -1801,6 +1817,8 @@ namespace bolt {
|
||||||
class RecordDeclaration : public Node {
|
class RecordDeclaration : public Node {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
InferContext* Ctx;
|
||||||
|
|
||||||
class PubKeyword* PubKeyword;
|
class PubKeyword* PubKeyword;
|
||||||
class StructKeyword* StructKeyword;
|
class StructKeyword* StructKeyword;
|
||||||
IdentifierAlt* Name;
|
IdentifierAlt* Name;
|
||||||
|
@ -1878,6 +1896,8 @@ namespace bolt {
|
||||||
class VariantDeclaration : public Node {
|
class VariantDeclaration : public Node {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
InferContext* Ctx;
|
||||||
|
|
||||||
class PubKeyword* PubKeyword;
|
class PubKeyword* PubKeyword;
|
||||||
class EnumKeyword* EnumKeyword;
|
class EnumKeyword* EnumKeyword;
|
||||||
class IdentifierAlt* Name;
|
class IdentifierAlt* Name;
|
||||||
|
@ -1912,6 +1932,7 @@ namespace bolt {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
TextFile& File;
|
TextFile& File;
|
||||||
|
InferContext* Ctx;
|
||||||
|
|
||||||
std::vector<Node*> Elements;
|
std::vector<Node*> Elements;
|
||||||
|
|
||||||
|
|
|
@ -157,10 +157,8 @@ namespace bolt {
|
||||||
TypeEnv Env;
|
TypeEnv Env;
|
||||||
|
|
||||||
Type* ReturnType = nullptr;
|
Type* ReturnType = nullptr;
|
||||||
std::vector<TypeclassSignature> Classes;
|
|
||||||
|
|
||||||
//inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType):
|
InferContext* Parent = nullptr;
|
||||||
// Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {}
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -183,22 +181,20 @@ namespace bolt {
|
||||||
|
|
||||||
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
|
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
|
||||||
|
|
||||||
std::vector<InferContext*> Contexts;
|
// std::vector<InferContext*> Contexts;
|
||||||
|
|
||||||
|
InferContext* ActiveContext;
|
||||||
|
|
||||||
|
InferContext& getContext();
|
||||||
|
void pushContext(InferContext* Ctx);
|
||||||
|
void popContext();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The queue that is used during solving to store any unsolved constraints.
|
* The queue that is used during solving to store any unsolved constraints.
|
||||||
*/
|
*/
|
||||||
std::deque<class Constraint*> Queue;
|
std::deque<class Constraint*> Queue;
|
||||||
|
|
||||||
/**
|
|
||||||
* Pointer to the current constraint being unified.
|
|
||||||
*/
|
|
||||||
CEqual* C;
|
|
||||||
|
|
||||||
InferContext& getContext();
|
|
||||||
|
|
||||||
void addConstraint(Constraint* Constraint);
|
void addConstraint(Constraint* Constraint);
|
||||||
void addClass(TypeclassSignature Sig);
|
|
||||||
|
|
||||||
void forwardDeclare(Node* Node);
|
void forwardDeclare(Node* Node);
|
||||||
void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
|
void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
|
||||||
|
@ -217,12 +213,18 @@ namespace bolt {
|
||||||
TCon* createConType(ByteString Name);
|
TCon* createConType(ByteString Name);
|
||||||
TVar* createTypeVar();
|
TVar* createTypeVar();
|
||||||
TVarRigid* createRigidVar(ByteString Name);
|
TVarRigid* createRigidVar(ByteString Name);
|
||||||
InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet);
|
InferContext* createInferContext(
|
||||||
|
InferContext* Parent = nullptr,
|
||||||
|
TVSet* TVs = new TVSet,
|
||||||
|
ConstraintSet* Constraints = new ConstraintSet
|
||||||
|
);
|
||||||
|
|
||||||
void addBinding(ByteString Name, Scheme* Scm);
|
void addBinding(ByteString Name, Scheme* Scm);
|
||||||
|
|
||||||
Scheme* lookup(ByteString Name);
|
Scheme* lookup(ByteString Name);
|
||||||
|
|
||||||
|
void initialize(Node* N);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Looks up a type/variable and ensures that it is a monomorphic type.
|
* Looks up a type/variable and ensures that it is a monomorphic type.
|
||||||
*
|
*
|
||||||
|
@ -248,6 +250,9 @@ namespace bolt {
|
||||||
void propagateClasses(TypeclassContext& Classes, Type* Ty);
|
void propagateClasses(TypeclassContext& Classes, Type* Ty);
|
||||||
void propagateClassTycon(TypeclassId& Class, TCon* Ty);
|
void propagateClassTycon(TypeclassId& Class, TCon* Ty);
|
||||||
|
|
||||||
|
// TODO Remove this
|
||||||
|
Node* Source;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Assign a type to a unification variable.
|
* Assign a type to a unification variable.
|
||||||
*
|
*
|
||||||
|
@ -260,14 +265,6 @@ namespace bolt {
|
||||||
*/
|
*/
|
||||||
void join(TVar* A, Type* B);
|
void join(TVar* A, Type* B);
|
||||||
|
|
||||||
// Unification parameters
|
|
||||||
Type* OrigLeft;
|
|
||||||
Type* OrigRight;
|
|
||||||
TypePath LeftPath;
|
|
||||||
TypePath RightPath;
|
|
||||||
ByteString CurrentFieldName;
|
|
||||||
Node* Source;
|
|
||||||
|
|
||||||
bool unify(Type* A, Type* B);
|
bool unify(Type* A, Type* B);
|
||||||
|
|
||||||
void unifyError();
|
void unifyError();
|
||||||
|
|
223
src/Checker.cc
223
src/Checker.cc
|
@ -98,12 +98,16 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
|
|
||||||
Scheme* Checker::lookup(ByteString Name) {
|
Scheme* Checker::lookup(ByteString Name) {
|
||||||
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
|
auto Curr = &getContext();
|
||||||
auto Curr = *Iter;
|
for (;;) {
|
||||||
auto Match = Curr->Env.find(Name);
|
auto Match = Curr->Env.find(Name);
|
||||||
if (Match != Curr->Env.end()) {
|
if (Match != Curr->Env.end()) {
|
||||||
return Match->second;
|
return Match->second;
|
||||||
}
|
}
|
||||||
|
Curr = Curr->Parent;
|
||||||
|
if (!Curr) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -119,11 +123,11 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::addBinding(ByteString Name, Scheme* Scm) {
|
void Checker::addBinding(ByteString Name, Scheme* Scm) {
|
||||||
Contexts.back()->Env.emplace(Name, Scm);
|
getContext().Env.emplace(Name, Scm);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::getReturnType() {
|
Type* Checker::getReturnType() {
|
||||||
auto Ty = Contexts.back()->ReturnType;
|
auto Ty = getContext().ReturnType;
|
||||||
ZEN_ASSERT(Ty != nullptr);
|
ZEN_ASSERT(Ty != nullptr);
|
||||||
return Ty;
|
return Ty;
|
||||||
}
|
}
|
||||||
|
@ -137,24 +141,43 @@ namespace bolt {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Checker::pushContext(InferContext* Ctx) {
|
||||||
|
ActiveContext = Ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Checker::popContext() {
|
||||||
|
ZEN_ASSERT(ActiveContext);
|
||||||
|
ActiveContext = ActiveContext->Parent;
|
||||||
|
}
|
||||||
|
|
||||||
InferContext& Checker::getContext() {
|
InferContext& Checker::getContext() {
|
||||||
ZEN_ASSERT(!Contexts.empty());
|
ZEN_ASSERT(ActiveContext);
|
||||||
return *Contexts.back();
|
return *ActiveContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::addConstraint(Constraint* C) {
|
void Checker::addConstraint(Constraint* C) {
|
||||||
switch (C->getKind()) {
|
switch (C->getKind()) {
|
||||||
case ConstraintKind::Class:
|
case ConstraintKind::Class:
|
||||||
{
|
{
|
||||||
Contexts.back()->Constraints->push_back(C);
|
getContext().Constraints->push_back(C);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case ConstraintKind::Equal:
|
case ConstraintKind::Equal:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<CEqual*>(C);
|
auto Y = static_cast<CEqual*>(C);
|
||||||
|
|
||||||
|
auto Curr = &getContext();
|
||||||
|
std::vector<InferContext*> Contexts;
|
||||||
|
for (;;) {
|
||||||
|
Contexts.push_back(Curr);
|
||||||
|
Curr = Curr->Parent;
|
||||||
|
if (!Curr) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::size_t MaxLevelLeft = 0;
|
std::size_t MaxLevelLeft = 0;
|
||||||
for (std::size_t I = Contexts.size(); I-- > 0; ) {
|
for (std::size_t I = 0; I < Contexts.size(); I++) {
|
||||||
auto Ctx = Contexts[I];
|
auto Ctx = Contexts[I];
|
||||||
if (hasTypeVar(*Ctx->TVs, Y->Left)) {
|
if (hasTypeVar(*Ctx->TVs, Y->Left)) {
|
||||||
MaxLevelLeft = I;
|
MaxLevelLeft = I;
|
||||||
|
@ -162,7 +185,7 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::size_t MaxLevelRight = 0;
|
std::size_t MaxLevelRight = 0;
|
||||||
for (std::size_t I = Contexts.size(); I-- > 0; ) {
|
for (std::size_t I = 0; I < Contexts.size(); I++) {
|
||||||
auto Ctx = Contexts[I];
|
auto Ctx = Contexts[I];
|
||||||
if (hasTypeVar(*Ctx->TVs, Y->Right)) {
|
if (hasTypeVar(*Ctx->TVs, Y->Right)) {
|
||||||
MaxLevelRight = I;
|
MaxLevelRight = I;
|
||||||
|
@ -172,7 +195,7 @@ namespace bolt {
|
||||||
auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight);
|
auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight);
|
||||||
|
|
||||||
std::size_t MinLevel = MaxLevel;
|
std::size_t MinLevel = MaxLevel;
|
||||||
for (std::size_t I = 0; I < Contexts.size(); I++) {
|
for (std::size_t I = Contexts.size(); I-- > 0; ) {
|
||||||
auto Ctx = Contexts[I];
|
auto Ctx = Contexts[I];
|
||||||
if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
|
if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
|
||||||
MinLevel = I;
|
MinLevel = I;
|
||||||
|
@ -180,7 +203,6 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0
|
|
||||||
if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) {
|
if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) {
|
||||||
solveCEqual(Y);
|
solveCEqual(Y);
|
||||||
} else {
|
} else {
|
||||||
|
@ -202,10 +224,6 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::addClass(TypeclassSignature Sig) {
|
|
||||||
getContext().Classes.push_back(Sig);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Checker::forwardDeclare(Node* X) {
|
void Checker::forwardDeclare(Node* X) {
|
||||||
|
|
||||||
switch (X->getKind()) {
|
switch (X->getKind()) {
|
||||||
|
@ -269,21 +287,19 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Decl = static_cast<VariantDeclaration*>(X);
|
auto Decl = static_cast<VariantDeclaration*>(X);
|
||||||
|
|
||||||
auto& ParentCtx = getContext();
|
pushContext(Decl->Ctx);
|
||||||
auto Ctx = createInferContext();
|
|
||||||
Contexts.push_back(Ctx);
|
|
||||||
|
|
||||||
std::vector<TVar*> Vars;
|
std::vector<TVar*> Vars;
|
||||||
for (auto TE: Decl->TVs) {
|
for (auto TE: Decl->TVs) {
|
||||||
auto TV = createRigidVar(TE->Name->getCanonicalText());
|
auto TV = createRigidVar(TE->Name->getCanonicalText());
|
||||||
Ctx->TVs->emplace(TV);
|
Decl->Ctx->TVs->emplace(TV);
|
||||||
Vars.push_back(TV);
|
Vars.push_back(TV);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Ty = createConType(Decl->Name->getCanonicalText());
|
Type* Ty = createConType(Decl->Name->getCanonicalText());
|
||||||
|
|
||||||
// Must be added early so we can create recursive types
|
// Must be added early so we can create recursive types
|
||||||
ParentCtx.Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty));
|
Decl->Ctx->Parent->Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty));
|
||||||
|
|
||||||
for (auto Member: Decl->Members) {
|
for (auto Member: Decl->Members) {
|
||||||
switch (Member->getKind()) {
|
switch (Member->getKind()) {
|
||||||
|
@ -298,7 +314,7 @@ namespace bolt {
|
||||||
for (auto Element: TupleMember->Elements) {
|
for (auto Element: TupleMember->Elements) {
|
||||||
ParamTypes.push_back(inferTypeExpression(Element));
|
ParamTypes.push_back(inferTypeExpression(Element));
|
||||||
}
|
}
|
||||||
ParentCtx.Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Ctx->TVs, Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
|
Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case NodeKind::RecordVariantDeclarationMember:
|
case NodeKind::RecordVariantDeclarationMember:
|
||||||
|
@ -311,7 +327,7 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Contexts.pop_back();
|
popContext();
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -320,13 +336,12 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Decl = static_cast<RecordDeclaration*>(X);
|
auto Decl = static_cast<RecordDeclaration*>(X);
|
||||||
|
|
||||||
auto& ParentCtx = getContext();
|
pushContext(Decl->Ctx);
|
||||||
auto Ctx = createInferContext();
|
|
||||||
Contexts.push_back(Ctx);
|
|
||||||
std::vector<TVar*> Vars;
|
std::vector<TVar*> Vars;
|
||||||
for (auto TE: Decl->Vars) {
|
for (auto TE: Decl->Vars) {
|
||||||
auto TV = createRigidVar(TE->Name->getCanonicalText());
|
auto TV = createRigidVar(TE->Name->getCanonicalText());
|
||||||
Ctx->TVs->emplace(TV);
|
Decl->Ctx->TVs->emplace(TV);
|
||||||
Vars.push_back(TV);
|
Vars.push_back(TV);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,9 +349,9 @@ namespace bolt {
|
||||||
auto Ty = createConType(Name);
|
auto Ty = createConType(Name);
|
||||||
|
|
||||||
// Must be added early so we can create recursive types
|
// Must be added early so we can create recursive types
|
||||||
ParentCtx.Env.emplace(Name, new Forall(Ty));
|
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Ty));
|
||||||
|
|
||||||
// Corresponds to the logic of one branch of a VaraintDeclarationMember
|
// Corresponds to the logic of one branch of a VariantDeclarationMember
|
||||||
Type* FieldsTy = new TNil();
|
Type* FieldsTy = new TNil();
|
||||||
for (auto Field: Decl->Fields) {
|
for (auto Field: Decl->Fields) {
|
||||||
FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy);
|
FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy);
|
||||||
|
@ -345,8 +360,8 @@ namespace bolt {
|
||||||
for (auto TV: Vars) {
|
for (auto TV: Vars) {
|
||||||
RetTy = new TApp(RetTy, TV);
|
RetTy = new TApp(RetTy, TV);
|
||||||
}
|
}
|
||||||
Contexts.pop_back();
|
popContext();
|
||||||
addBinding(Name, new Forall(Ctx->TVs, Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
|
addBinding(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -358,24 +373,70 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Checker::initialize(Node* N) {
|
||||||
|
|
||||||
|
struct Init : public CSTVisitor<Init> {
|
||||||
|
|
||||||
|
Checker& C;
|
||||||
|
|
||||||
|
std::stack<InferContext*> Contexts;
|
||||||
|
|
||||||
|
InferContext* createDerivedContext() {
|
||||||
|
return C.createInferContext(Contexts.top());
|
||||||
|
}
|
||||||
|
|
||||||
|
void visitVariantDeclaration(VariantDeclaration* Decl) {
|
||||||
|
Decl->Ctx = createDerivedContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
void visitRecordDeclaration(RecordDeclaration* Decl) {
|
||||||
|
Decl->Ctx = createDerivedContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
void visitMatchCase(MatchCase* C) {
|
||||||
|
C->Ctx = createDerivedContext();
|
||||||
|
Contexts.push(C->Ctx);
|
||||||
|
visitEachChild(C);
|
||||||
|
Contexts.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
void visitSourceFile(SourceFile* SF) {
|
||||||
|
SF->Ctx = C.createInferContext();
|
||||||
|
Contexts.push(SF->Ctx);
|
||||||
|
visitEachChild(SF);
|
||||||
|
Contexts.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
void visitLetDeclaration(LetDeclaration* Let) {
|
||||||
|
if (Let->isFunc()) {
|
||||||
|
Let->Ctx = createDerivedContext();
|
||||||
|
Contexts.push(Let->Ctx);
|
||||||
|
visitEachChild(Let);
|
||||||
|
Contexts.pop();
|
||||||
|
} else {
|
||||||
|
Let->Ctx = Contexts.top();
|
||||||
|
visitEachChild(Let);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
Init I { {}, *this };
|
||||||
|
I.visit(N);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) {
|
void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) {
|
||||||
|
|
||||||
auto Let = static_cast<LetDeclaration*>(N);
|
auto Let = static_cast<LetDeclaration*>(N);
|
||||||
bool IsFunc = !Let->Params.empty();
|
|
||||||
bool IsInstance = llvm::isa<InstanceDeclaration>(Let->Parent);
|
|
||||||
bool IsClass = llvm::isa<ClassDeclaration>(Let->Parent);
|
|
||||||
bool HasContext = IsFunc || IsInstance || IsClass;
|
|
||||||
|
|
||||||
if (HasContext) {
|
pushContext(Let->Ctx);
|
||||||
Let->Ctx = createInferContext(TVs, Constraints);
|
|
||||||
Contexts.push_back(Let->Ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If declaring a let-declaration inside a type class declaration,
|
// If declaring a let-declaration inside a type class declaration,
|
||||||
// we need to mark that the let-declaration requires this class.
|
// we need to mark that the let-declaration requires this class.
|
||||||
// This marking is set on the rigid type variables of the class, which
|
// This marking is set on the rigid type variables of the class, which
|
||||||
// are then added to this local type environment.
|
// are then added to this local type environment.
|
||||||
if (IsClass) {
|
if (Let->isClass()) {
|
||||||
auto Class = static_cast<ClassDeclaration*>(Let->Parent);
|
auto Class = static_cast<ClassDeclaration*>(Let->Parent);
|
||||||
for (auto TE: Class->TypeVars) {
|
for (auto TE: Class->TypeVars) {
|
||||||
auto TV = llvm::cast<TVar>(TE->getType());
|
auto TV = llvm::cast<TVar>(TE->getType());
|
||||||
|
@ -400,7 +461,7 @@ namespace bolt {
|
||||||
// we need to perform some work to make sure the type asserts of the
|
// we need to perform some work to make sure the type asserts of the
|
||||||
// corresponding let-declaration in the type class declaration are
|
// corresponding let-declaration in the type class declaration are
|
||||||
// accounted for.
|
// accounted for.
|
||||||
if (IsInstance) {
|
if (Let->isInstance()) {
|
||||||
|
|
||||||
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
|
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
|
||||||
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
|
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
|
||||||
|
@ -443,7 +504,7 @@ namespace bolt {
|
||||||
case NodeKind::LetBlockBody:
|
case NodeKind::LetBlockBody:
|
||||||
{
|
{
|
||||||
auto Block = static_cast<LetBlockBody*>(Let->Body);
|
auto Block = static_cast<LetBlockBody*>(Let->Body);
|
||||||
if (IsFunc) {
|
if (Let->isFunc()) {
|
||||||
Let->Ctx->ReturnType = createTypeVar();
|
Let->Ctx->ReturnType = createTypeVar();
|
||||||
}
|
}
|
||||||
for (auto Element: Block->Elements) {
|
for (auto Element: Block->Elements) {
|
||||||
|
@ -456,9 +517,10 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
popContext();
|
||||||
|
|
||||||
Type* BindTy;
|
Type* BindTy;
|
||||||
if (HasContext) {
|
if (Let->isFunc()) {
|
||||||
Contexts.pop_back();
|
|
||||||
BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
|
BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
|
||||||
} else {
|
} else {
|
||||||
BindTy = inferPattern(Let->Pattern);
|
BindTy = inferPattern(Let->Pattern);
|
||||||
|
@ -467,17 +529,9 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::inferLetDeclaration(LetDeclaration* N) {
|
void Checker::inferLetDeclaration(LetDeclaration* Decl) {
|
||||||
|
|
||||||
auto Decl = static_cast<LetDeclaration*>(N);
|
pushContext(Decl->Ctx);
|
||||||
bool IsFunc = !Decl->Params.empty();
|
|
||||||
bool IsInstance = llvm::isa<InstanceDeclaration>(Decl->Parent);
|
|
||||||
bool IsClass = llvm::isa<ClassDeclaration>(Decl->Parent);
|
|
||||||
bool HasContext = IsFunc || IsInstance || IsClass;
|
|
||||||
|
|
||||||
if (HasContext) {
|
|
||||||
Contexts.push_back(Decl->Ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Type*> ParamTypes;
|
std::vector<Type*> ParamTypes;
|
||||||
Type* RetType;
|
Type* RetType;
|
||||||
|
@ -498,7 +552,7 @@ namespace bolt {
|
||||||
case NodeKind::LetBlockBody:
|
case NodeKind::LetBlockBody:
|
||||||
{
|
{
|
||||||
auto Block = static_cast<LetBlockBody*>(Decl->Body);
|
auto Block = static_cast<LetBlockBody*>(Decl->Body);
|
||||||
ZEN_ASSERT(HasContext);
|
ZEN_ASSERT(Decl->isFunc());
|
||||||
RetType = Decl->Ctx->ReturnType;
|
RetType = Decl->Ctx->ReturnType;
|
||||||
for (auto Element: Block->Elements) {
|
for (auto Element: Block->Elements) {
|
||||||
infer(Element);
|
infer(Element);
|
||||||
|
@ -512,15 +566,13 @@ namespace bolt {
|
||||||
RetType = createTypeVar();
|
RetType = createTypeVar();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (HasContext) {
|
popContext();
|
||||||
Contexts.pop_back();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (IsFunc) {
|
if (Decl->isFunc()) {
|
||||||
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N });
|
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
|
||||||
} else {
|
} else {
|
||||||
// Declaration is a plain (typed) variable
|
// Declaration is a plain (typed) variable
|
||||||
addConstraint(new CEqual { Decl->Ty, RetType, N });
|
addConstraint(new CEqual { Decl->Ty, RetType, Decl });
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -611,18 +663,19 @@ namespace bolt {
|
||||||
|
|
||||||
TVarRigid* Checker::createRigidVar(ByteString Name) {
|
TVarRigid* Checker::createRigidVar(ByteString Name) {
|
||||||
auto TV = new TVarRigid(NextTypeVarId++, Name);
|
auto TV = new TVarRigid(NextTypeVarId++, Name);
|
||||||
Contexts.back()->TVs->emplace(TV);
|
getContext().TVs->emplace(TV);
|
||||||
return TV;
|
return TV;
|
||||||
}
|
}
|
||||||
|
|
||||||
TVar* Checker::createTypeVar() {
|
TVar* Checker::createTypeVar() {
|
||||||
auto TV = new TVar(NextTypeVarId++, VarKind::Unification);
|
auto TV = new TVar(NextTypeVarId++, VarKind::Unification);
|
||||||
Contexts.back()->TVs->emplace(TV);
|
getContext().TVs->emplace(TV);
|
||||||
return TV;
|
return TV;
|
||||||
}
|
}
|
||||||
|
|
||||||
InferContext* Checker::createInferContext(TVSet* TVs, ConstraintSet* Constraints) {
|
InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) {
|
||||||
auto Ctx = new InferContext;
|
auto Ctx = new InferContext;
|
||||||
|
Ctx->Parent = Parent;
|
||||||
Ctx->TVs = new TVSet;
|
Ctx->TVs = new TVSet;
|
||||||
Ctx->Constraints = new ConstraintSet;
|
Ctx->Constraints = new ConstraintSet;
|
||||||
return Ctx;
|
return Ctx;
|
||||||
|
@ -813,13 +866,12 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
Ty = createTypeVar();
|
Ty = createTypeVar();
|
||||||
for (auto Case: Match->Cases) {
|
for (auto Case: Match->Cases) {
|
||||||
auto NewCtx = createInferContext();
|
pushContext(Case->Ctx);
|
||||||
Contexts.push_back(NewCtx);
|
|
||||||
auto PattTy = inferPattern(Case->Pattern);
|
auto PattTy = inferPattern(Case->Pattern);
|
||||||
addConstraint(new CEqual(PattTy, ValTy, X));
|
addConstraint(new CEqual(PattTy, ValTy, X));
|
||||||
auto ExprTy = inferExpression(Case->Expression);
|
auto ExprTy = inferExpression(Case->Expression);
|
||||||
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
|
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
|
||||||
Contexts.pop_back();
|
popContext();
|
||||||
}
|
}
|
||||||
if (!Match->Value) {
|
if (!Match->Value) {
|
||||||
Ty = new TArrow({ ValTy }, Ty);
|
Ty = new TArrow({ ValTy }, Ty);
|
||||||
|
@ -1036,20 +1088,25 @@ namespace bolt {
|
||||||
auto Y = static_cast<ReferenceExpression*>(N);
|
auto Y = static_cast<ReferenceExpression*>(N);
|
||||||
auto Def = Y->getScope()->lookup(Y->getSymbolPath());
|
auto Def = Y->getScope()->lookup(Y->getSymbolPath());
|
||||||
// Name lookup failures will be reported directly in inferExpression().
|
// Name lookup failures will be reported directly in inferExpression().
|
||||||
// Parameters are clearly no let-decarations. They never have their own
|
if (Def == nullptr || Def->getKind() == NodeKind::SourceFile) {
|
||||||
// inference context, so we have to skip them.
|
|
||||||
if (Def == nullptr || Def->getKind() == NodeKind::Parameter) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration || Def->getKind() == NodeKind::SourceFile);
|
// This case ensures that a deeply nested structure that references a
|
||||||
RefGraph.addEdge(Stack.top(), Def);
|
// parameter of a parent node but is not referenced itself is correctly handled.
|
||||||
|
// Note that the edge goes from the parent let to the parameter. This is normal.
|
||||||
|
if (Def->getKind() == NodeKind::Parameter) {
|
||||||
|
RefGraph.addEdge(Stack.top(), Def->Parent);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration);
|
||||||
|
if (!Stack.empty()) {
|
||||||
|
RefGraph.addEdge(Def, Stack.top());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
RefGraph.addVertex(SF);
|
|
||||||
Visitor V { {}, RefGraph };
|
Visitor V { {}, RefGraph };
|
||||||
V.Stack.push(SF);
|
|
||||||
V.visit(SF);
|
V.visit(SF);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1189,8 +1246,8 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::check(SourceFile *SF) {
|
void Checker::check(SourceFile *SF) {
|
||||||
auto RootContext = createInferContext();
|
initialize(SF);
|
||||||
Contexts.push_back(RootContext);
|
pushContext(SF->Ctx);
|
||||||
addBinding("String", new Forall(StringType));
|
addBinding("String", new Forall(StringType));
|
||||||
addBinding("Int", new Forall(IntType));
|
addBinding("Int", new Forall(IntType));
|
||||||
addBinding("Bool", new Forall(BoolType));
|
addBinding("Bool", new Forall(BoolType));
|
||||||
|
@ -1206,9 +1263,6 @@ namespace bolt {
|
||||||
forwardDeclare(SF);
|
forwardDeclare(SF);
|
||||||
auto SCCs = RefGraph.strongconnect();
|
auto SCCs = RefGraph.strongconnect();
|
||||||
for (auto Nodes: SCCs) {
|
for (auto Nodes: SCCs) {
|
||||||
if (Nodes.size() == 1 && llvm::isa<SourceFile>(Nodes[0])) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto TVs = new TVSet;
|
auto TVs = new TVSet;
|
||||||
auto Constraints = new ConstraintSet;
|
auto Constraints = new ConstraintSet;
|
||||||
for (auto N: Nodes) {
|
for (auto N: Nodes) {
|
||||||
|
@ -1217,9 +1271,6 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto Nodes: SCCs) {
|
for (auto Nodes: SCCs) {
|
||||||
if (Nodes.size() == 1 && llvm::isa<SourceFile>(Nodes[0])) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (auto N: Nodes) {
|
for (auto N: Nodes) {
|
||||||
auto Decl = static_cast<LetDeclaration*>(N);
|
auto Decl = static_cast<LetDeclaration*>(N);
|
||||||
Decl->IsCycleActive = true;
|
Decl->IsCycleActive = true;
|
||||||
|
@ -1234,8 +1285,8 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
infer(SF);
|
infer(SF);
|
||||||
Contexts.pop_back();
|
popContext();
|
||||||
solve(new CMany(*RootContext->Constraints));
|
solve(new CMany(*SF->Ctx->Constraints));
|
||||||
checkTypeclassSigs(SF);
|
checkTypeclassSigs(SF);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1349,6 +1400,8 @@ namespace bolt {
|
||||||
|
|
||||||
void Checker::join(TVar* TV, Type* Ty) {
|
void Checker::join(TVar* TV, Type* Ty) {
|
||||||
|
|
||||||
|
// std::cerr << describe(TV) << " => " << describe(Ty) << std::endl;
|
||||||
|
|
||||||
TV->set(Ty);
|
TV->set(Ty);
|
||||||
|
|
||||||
propagateClasses(TV->Contexts, Ty);
|
propagateClasses(TV->Contexts, Ty);
|
||||||
|
@ -1364,9 +1417,9 @@ namespace bolt {
|
||||||
// Should it get assigned another unification variable, that's OK too
|
// Should it get assigned another unification variable, that's OK too
|
||||||
// because then that variable is what matters and it will become the new
|
// because then that variable is what matters and it will become the new
|
||||||
// (possibly polymorphic) variable.
|
// (possibly polymorphic) variable.
|
||||||
if (!Contexts.empty()) {
|
if (ActiveContext) {
|
||||||
// std::cerr << "erase " << describe(TV) << std::endl;
|
// std::cerr << "erase " << describe(TV) << std::endl;
|
||||||
auto TVs = Contexts.back()->TVs;
|
auto TVs = ActiveContext->TVs;
|
||||||
TVs->erase(TV);
|
TVs->erase(TV);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue