Make InferContext have a parent context

This commit is contained in:
Sam Vervaeck 2023-05-30 13:37:47 +02:00
parent 83e89f4e8c
commit a8f8658f27
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
3 changed files with 180 additions and 109 deletions

View file

@ -17,6 +17,7 @@
namespace bolt {
class Type;
class InferContext;
class Token;
class SourceFile;
@ -1279,6 +1280,8 @@ namespace bolt {
class MatchCase : public Node {
public:
InferContext* Ctx;
class Pattern* Pattern;
class RArrowAlt* RArrowAlt;
class Expression* Expression;
@ -1658,9 +1661,6 @@ namespace bolt {
};
class Type;
class InferContext;
class LetDeclaration : public Node {
Scope* TheScope = nullptr;
@ -1703,6 +1703,22 @@ namespace bolt {
return TheScope;
}
bool isFunc() const noexcept {
return !Params.empty();
}
bool isVar() const noexcept {
return !isFunc();
}
bool isInstance() const noexcept {
return Parent->getKind() == NodeKind::InstanceDeclaration;
}
bool isClass() const noexcept {
return Parent->getKind() == NodeKind::ClassDeclaration;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
@ -1801,6 +1817,8 @@ namespace bolt {
class RecordDeclaration : public Node {
public:
InferContext* Ctx;
class PubKeyword* PubKeyword;
class StructKeyword* StructKeyword;
IdentifierAlt* Name;
@ -1878,6 +1896,8 @@ namespace bolt {
class VariantDeclaration : public Node {
public:
InferContext* Ctx;
class PubKeyword* PubKeyword;
class EnumKeyword* EnumKeyword;
class IdentifierAlt* Name;
@ -1912,6 +1932,7 @@ namespace bolt {
public:
TextFile& File;
InferContext* Ctx;
std::vector<Node*> Elements;

View file

@ -157,10 +157,8 @@ namespace bolt {
TypeEnv Env;
Type* ReturnType = nullptr;
std::vector<TypeclassSignature> Classes;
//inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType):
// Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {}
InferContext* Parent = nullptr;
};
@ -183,22 +181,20 @@ namespace bolt {
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
std::vector<InferContext*> Contexts;
// std::vector<InferContext*> Contexts;
InferContext* ActiveContext;
InferContext& getContext();
void pushContext(InferContext* Ctx);
void popContext();
/**
* The queue that is used during solving to store any unsolved constraints.
*/
std::deque<class Constraint*> Queue;
/**
* Pointer to the current constraint being unified.
*/
CEqual* C;
InferContext& getContext();
void addConstraint(Constraint* Constraint);
void addClass(TypeclassSignature Sig);
void forwardDeclare(Node* Node);
void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
@ -217,12 +213,18 @@ namespace bolt {
TCon* createConType(ByteString Name);
TVar* createTypeVar();
TVarRigid* createRigidVar(ByteString Name);
InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet);
InferContext* createInferContext(
InferContext* Parent = nullptr,
TVSet* TVs = new TVSet,
ConstraintSet* Constraints = new ConstraintSet
);
void addBinding(ByteString Name, Scheme* Scm);
Scheme* lookup(ByteString Name);
void initialize(Node* N);
/**
* Looks up a type/variable and ensures that it is a monomorphic type.
*
@ -248,6 +250,9 @@ namespace bolt {
void propagateClasses(TypeclassContext& Classes, Type* Ty);
void propagateClassTycon(TypeclassId& Class, TCon* Ty);
// TODO Remove this
Node* Source;
/**
* Assign a type to a unification variable.
*
@ -260,14 +265,6 @@ namespace bolt {
*/
void join(TVar* A, Type* B);
// Unification parameters
Type* OrigLeft;
Type* OrigRight;
TypePath LeftPath;
TypePath RightPath;
ByteString CurrentFieldName;
Node* Source;
bool unify(Type* A, Type* B);
void unifyError();

View file

@ -98,12 +98,16 @@ namespace bolt {
}
Scheme* Checker::lookup(ByteString Name) {
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
auto Curr = *Iter;
auto Curr = &getContext();
for (;;) {
auto Match = Curr->Env.find(Name);
if (Match != Curr->Env.end()) {
return Match->second;
}
Curr = Curr->Parent;
if (!Curr) {
break;
}
}
return nullptr;
}
@ -119,11 +123,11 @@ namespace bolt {
}
void Checker::addBinding(ByteString Name, Scheme* Scm) {
Contexts.back()->Env.emplace(Name, Scm);
getContext().Env.emplace(Name, Scm);
}
Type* Checker::getReturnType() {
auto Ty = Contexts.back()->ReturnType;
auto Ty = getContext().ReturnType;
ZEN_ASSERT(Ty != nullptr);
return Ty;
}
@ -137,24 +141,43 @@ namespace bolt {
return false;
}
void Checker::pushContext(InferContext* Ctx) {
ActiveContext = Ctx;
}
void Checker::popContext() {
ZEN_ASSERT(ActiveContext);
ActiveContext = ActiveContext->Parent;
}
InferContext& Checker::getContext() {
ZEN_ASSERT(!Contexts.empty());
return *Contexts.back();
ZEN_ASSERT(ActiveContext);
return *ActiveContext;
}
void Checker::addConstraint(Constraint* C) {
switch (C->getKind()) {
case ConstraintKind::Class:
{
Contexts.back()->Constraints->push_back(C);
getContext().Constraints->push_back(C);
break;
}
case ConstraintKind::Equal:
{
auto Y = static_cast<CEqual*>(C);
auto Curr = &getContext();
std::vector<InferContext*> Contexts;
for (;;) {
Contexts.push_back(Curr);
Curr = Curr->Parent;
if (!Curr) {
break;
}
}
std::size_t MaxLevelLeft = 0;
for (std::size_t I = Contexts.size(); I-- > 0; ) {
for (std::size_t I = 0; I < Contexts.size(); I++) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Left)) {
MaxLevelLeft = I;
@ -162,7 +185,7 @@ namespace bolt {
}
}
std::size_t MaxLevelRight = 0;
for (std::size_t I = Contexts.size(); I-- > 0; ) {
for (std::size_t I = 0; I < Contexts.size(); I++) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Right)) {
MaxLevelRight = I;
@ -172,7 +195,7 @@ namespace bolt {
auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight);
std::size_t MinLevel = MaxLevel;
for (std::size_t I = 0; I < Contexts.size(); I++) {
for (std::size_t I = Contexts.size(); I-- > 0; ) {
auto Ctx = Contexts[I];
if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
MinLevel = I;
@ -180,7 +203,6 @@ namespace bolt {
}
}
// TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0
if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) {
solveCEqual(Y);
} else {
@ -202,10 +224,6 @@ namespace bolt {
}
}
void Checker::addClass(TypeclassSignature Sig) {
getContext().Classes.push_back(Sig);
}
void Checker::forwardDeclare(Node* X) {
switch (X->getKind()) {
@ -269,21 +287,19 @@ namespace bolt {
{
auto Decl = static_cast<VariantDeclaration*>(X);
auto& ParentCtx = getContext();
auto Ctx = createInferContext();
Contexts.push_back(Ctx);
pushContext(Decl->Ctx);
std::vector<TVar*> Vars;
for (auto TE: Decl->TVs) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Ctx->TVs->emplace(TV);
Decl->Ctx->TVs->emplace(TV);
Vars.push_back(TV);
}
Type* Ty = createConType(Decl->Name->getCanonicalText());
// Must be added early so we can create recursive types
ParentCtx.Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty));
Decl->Ctx->Parent->Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty));
for (auto Member: Decl->Members) {
switch (Member->getKind()) {
@ -298,7 +314,7 @@ namespace bolt {
for (auto Element: TupleMember->Elements) {
ParamTypes.push_back(inferTypeExpression(Element));
}
ParentCtx.Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Ctx->TVs, Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
break;
}
case NodeKind::RecordVariantDeclarationMember:
@ -311,7 +327,7 @@ namespace bolt {
}
}
Contexts.pop_back();
popContext();
break;
}
@ -320,13 +336,12 @@ namespace bolt {
{
auto Decl = static_cast<RecordDeclaration*>(X);
auto& ParentCtx = getContext();
auto Ctx = createInferContext();
Contexts.push_back(Ctx);
pushContext(Decl->Ctx);
std::vector<TVar*> Vars;
for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Ctx->TVs->emplace(TV);
Decl->Ctx->TVs->emplace(TV);
Vars.push_back(TV);
}
@ -334,9 +349,9 @@ namespace bolt {
auto Ty = createConType(Name);
// Must be added early so we can create recursive types
ParentCtx.Env.emplace(Name, new Forall(Ty));
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Ty));
// Corresponds to the logic of one branch of a VaraintDeclarationMember
// Corresponds to the logic of one branch of a VariantDeclarationMember
Type* FieldsTy = new TNil();
for (auto Field: Decl->Fields) {
FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy);
@ -345,8 +360,8 @@ namespace bolt {
for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV);
}
Contexts.pop_back();
addBinding(Name, new Forall(Ctx->TVs, Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
popContext();
addBinding(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
break;
}
@ -358,24 +373,70 @@ namespace bolt {
}
void Checker::initialize(Node* N) {
struct Init : public CSTVisitor<Init> {
Checker& C;
std::stack<InferContext*> Contexts;
InferContext* createDerivedContext() {
return C.createInferContext(Contexts.top());
}
void visitVariantDeclaration(VariantDeclaration* Decl) {
Decl->Ctx = createDerivedContext();
}
void visitRecordDeclaration(RecordDeclaration* Decl) {
Decl->Ctx = createDerivedContext();
}
void visitMatchCase(MatchCase* C) {
C->Ctx = createDerivedContext();
Contexts.push(C->Ctx);
visitEachChild(C);
Contexts.pop();
}
void visitSourceFile(SourceFile* SF) {
SF->Ctx = C.createInferContext();
Contexts.push(SF->Ctx);
visitEachChild(SF);
Contexts.pop();
}
void visitLetDeclaration(LetDeclaration* Let) {
if (Let->isFunc()) {
Let->Ctx = createDerivedContext();
Contexts.push(Let->Ctx);
visitEachChild(Let);
Contexts.pop();
} else {
Let->Ctx = Contexts.top();
visitEachChild(Let);
}
}
};
Init I { {}, *this };
I.visit(N);
}
void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) {
auto Let = static_cast<LetDeclaration*>(N);
bool IsFunc = !Let->Params.empty();
bool IsInstance = llvm::isa<InstanceDeclaration>(Let->Parent);
bool IsClass = llvm::isa<ClassDeclaration>(Let->Parent);
bool HasContext = IsFunc || IsInstance || IsClass;
if (HasContext) {
Let->Ctx = createInferContext(TVs, Constraints);
Contexts.push_back(Let->Ctx);
}
pushContext(Let->Ctx);
// If declaring a let-declaration inside a type class declaration,
// we need to mark that the let-declaration requires this class.
// This marking is set on the rigid type variables of the class, which
// are then added to this local type environment.
if (IsClass) {
if (Let->isClass()) {
auto Class = static_cast<ClassDeclaration*>(Let->Parent);
for (auto TE: Class->TypeVars) {
auto TV = llvm::cast<TVar>(TE->getType());
@ -400,7 +461,7 @@ namespace bolt {
// we need to perform some work to make sure the type asserts of the
// corresponding let-declaration in the type class declaration are
// accounted for.
if (IsInstance) {
if (Let->isInstance()) {
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
@ -443,7 +504,7 @@ namespace bolt {
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Let->Body);
if (IsFunc) {
if (Let->isFunc()) {
Let->Ctx->ReturnType = createTypeVar();
}
for (auto Element: Block->Elements) {
@ -456,9 +517,10 @@ namespace bolt {
}
}
popContext();
Type* BindTy;
if (HasContext) {
Contexts.pop_back();
if (Let->isFunc()) {
BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
} else {
BindTy = inferPattern(Let->Pattern);
@ -467,17 +529,9 @@ namespace bolt {
}
void Checker::inferLetDeclaration(LetDeclaration* N) {
void Checker::inferLetDeclaration(LetDeclaration* Decl) {
auto Decl = static_cast<LetDeclaration*>(N);
bool IsFunc = !Decl->Params.empty();
bool IsInstance = llvm::isa<InstanceDeclaration>(Decl->Parent);
bool IsClass = llvm::isa<ClassDeclaration>(Decl->Parent);
bool HasContext = IsFunc || IsInstance || IsClass;
if (HasContext) {
Contexts.push_back(Decl->Ctx);
}
pushContext(Decl->Ctx);
std::vector<Type*> ParamTypes;
Type* RetType;
@ -498,7 +552,7 @@ namespace bolt {
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Decl->Body);
ZEN_ASSERT(HasContext);
ZEN_ASSERT(Decl->isFunc());
RetType = Decl->Ctx->ReturnType;
for (auto Element: Block->Elements) {
infer(Element);
@ -512,15 +566,13 @@ namespace bolt {
RetType = createTypeVar();
}
if (HasContext) {
Contexts.pop_back();
}
popContext();
if (IsFunc) {
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N });
if (Decl->isFunc()) {
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
} else {
// Declaration is a plain (typed) variable
addConstraint(new CEqual { Decl->Ty, RetType, N });
addConstraint(new CEqual { Decl->Ty, RetType, Decl });
}
}
@ -611,18 +663,19 @@ namespace bolt {
TVarRigid* Checker::createRigidVar(ByteString Name) {
auto TV = new TVarRigid(NextTypeVarId++, Name);
Contexts.back()->TVs->emplace(TV);
getContext().TVs->emplace(TV);
return TV;
}
TVar* Checker::createTypeVar() {
auto TV = new TVar(NextTypeVarId++, VarKind::Unification);
Contexts.back()->TVs->emplace(TV);
getContext().TVs->emplace(TV);
return TV;
}
InferContext* Checker::createInferContext(TVSet* TVs, ConstraintSet* Constraints) {
InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) {
auto Ctx = new InferContext;
Ctx->Parent = Parent;
Ctx->TVs = new TVSet;
Ctx->Constraints = new ConstraintSet;
return Ctx;
@ -813,13 +866,12 @@ namespace bolt {
}
Ty = createTypeVar();
for (auto Case: Match->Cases) {
auto NewCtx = createInferContext();
Contexts.push_back(NewCtx);
pushContext(Case->Ctx);
auto PattTy = inferPattern(Case->Pattern);
addConstraint(new CEqual(PattTy, ValTy, X));
auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
Contexts.pop_back();
popContext();
}
if (!Match->Value) {
Ty = new TArrow({ ValTy }, Ty);
@ -1036,20 +1088,25 @@ namespace bolt {
auto Y = static_cast<ReferenceExpression*>(N);
auto Def = Y->getScope()->lookup(Y->getSymbolPath());
// Name lookup failures will be reported directly in inferExpression().
// Parameters are clearly no let-decarations. They never have their own
// inference context, so we have to skip them.
if (Def == nullptr || Def->getKind() == NodeKind::Parameter) {
if (Def == nullptr || Def->getKind() == NodeKind::SourceFile) {
return;
}
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration || Def->getKind() == NodeKind::SourceFile);
RefGraph.addEdge(Stack.top(), Def);
// This case ensures that a deeply nested structure that references a
// parameter of a parent node but is not referenced itself is correctly handled.
// Note that the edge goes from the parent let to the parameter. This is normal.
if (Def->getKind() == NodeKind::Parameter) {
RefGraph.addEdge(Stack.top(), Def->Parent);
return;
}
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration);
if (!Stack.empty()) {
RefGraph.addEdge(Def, Stack.top());
}
}
};
RefGraph.addVertex(SF);
Visitor V { {}, RefGraph };
V.Stack.push(SF);
V.visit(SF);
}
@ -1189,8 +1246,8 @@ namespace bolt {
}
void Checker::check(SourceFile *SF) {
auto RootContext = createInferContext();
Contexts.push_back(RootContext);
initialize(SF);
pushContext(SF->Ctx);
addBinding("String", new Forall(StringType));
addBinding("Int", new Forall(IntType));
addBinding("Bool", new Forall(BoolType));
@ -1206,9 +1263,6 @@ namespace bolt {
forwardDeclare(SF);
auto SCCs = RefGraph.strongconnect();
for (auto Nodes: SCCs) {
if (Nodes.size() == 1 && llvm::isa<SourceFile>(Nodes[0])) {
continue;
}
auto TVs = new TVSet;
auto Constraints = new ConstraintSet;
for (auto N: Nodes) {
@ -1217,9 +1271,6 @@ namespace bolt {
}
}
for (auto Nodes: SCCs) {
if (Nodes.size() == 1 && llvm::isa<SourceFile>(Nodes[0])) {
continue;
}
for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N);
Decl->IsCycleActive = true;
@ -1234,8 +1285,8 @@ namespace bolt {
}
}
infer(SF);
Contexts.pop_back();
solve(new CMany(*RootContext->Constraints));
popContext();
solve(new CMany(*SF->Ctx->Constraints));
checkTypeclassSigs(SF);
}
@ -1349,6 +1400,8 @@ namespace bolt {
void Checker::join(TVar* TV, Type* Ty) {
// std::cerr << describe(TV) << " => " << describe(Ty) << std::endl;
TV->set(Ty);
propagateClasses(TV->Contexts, Ty);
@ -1364,9 +1417,9 @@ namespace bolt {
// Should it get assigned another unification variable, that's OK too
// because then that variable is what matters and it will become the new
// (possibly polymorphic) variable.
if (!Contexts.empty()) {
if (ActiveContext) {
// std::cerr << "erase " << describe(TV) << std::endl;
auto TVs = Contexts.back()->TVs;
auto TVs = ActiveContext->TVs;
TVs->erase(TV);
}