diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 59e1ed598..753826fbe 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -160,9 +160,9 @@ namespace bolt { case NodeKind::LetDeclaration: return static_cast(this)->visitLetDeclaration(static_cast(N)); case NodeKind::RecordDeclarationField: - return static_cast(this)->visitStructDeclarationField(static_cast(N)); + return static_cast(this)->visitRecordDeclarationField(static_cast(N)); case NodeKind::RecordDeclaration: - return static_cast(this)->visitStructDeclaration(static_cast(N)); + return static_cast(this)->visitRecordDeclaration(static_cast(N)); case NodeKind::VariantDeclaration: return static_cast(this)->visitVariantDeclaration(static_cast(N)); case NodeKind::TupleVariantDeclarationMember: @@ -504,11 +504,11 @@ namespace bolt { visitNode(N); } - void visitStructDeclarationField(RecordDeclarationField* N) { + void visitRecordDeclarationField(RecordDeclarationField* N) { visitNode(N); } - void visitStructDeclaration(RecordDeclaration* N) { + void visitRecordDeclaration(RecordDeclaration* N) { visitNode(N); } diff --git a/src/CST.cc b/src/CST.cc index fdd0cc129..8856a8ea8 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -89,6 +89,7 @@ namespace bolt { switch (X->getKind()) { case NodeKind::LetExprBody: case NodeKind::ExpressionStatement: + case NodeKind::IfStatement: case NodeKind::ReturnStatement: break; case NodeKind::LetBlockBody: @@ -99,6 +100,9 @@ namespace bolt { } break; } + case NodeKind::InstanceDeclaration: + // We ignore let-declarations inside instance-declarations for now + break; case NodeKind::ClassDeclaration: { auto Decl = static_cast(X); diff --git a/src/Checker.cc b/src/Checker.cc index 594753c2f..66291dd29 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -166,6 +166,9 @@ namespace bolt { { auto Y = static_cast(C); + // This will store all inference contexts in Contexts, from most local + // one to most general one. Because this order is not ideal, the code + // below will have to handle that. auto Curr = &getContext(); std::vector Contexts; for (;;) { @@ -176,28 +179,42 @@ namespace bolt { } } + // If no MaxLevelLeft was found, that means that not a single + // corresponding type variable was found in the contexts. We set it to + // 0, which corresponds to the global inference context. std::size_t MaxLevelLeft = 0; for (std::size_t I = 0; I < Contexts.size(); I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left)) { - MaxLevelLeft = I; + MaxLevelLeft = Contexts.size() - I - 1; break; } } + + // Same as above but now mirrored for Y->Right std::size_t MaxLevelRight = 0; for (std::size_t I = 0; I < Contexts.size(); I++) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Right)) { - MaxLevelRight = I; + MaxLevelRight = Contexts.size() - I - 1; break; } } + + // The lowest index is determined by the one that has no type variables + // in Y->Left AND in Y->Right. This implies max() must be used, so that + // the very first enounter of a type variable matters. auto MaxLevel = std::max(MaxLevelLeft, MaxLevelRight); + // Now find the highest index I such that all the contexts that are more + // local do not contain any type variables that are present in the + // equality constraint. std::size_t MinLevel = MaxLevel; for (std::size_t I = Contexts.size(); I-- > 0; ) { auto Ctx = Contexts[I]; if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { + // No need to reverse because even though Contexts is reversed, we + // are also iterating in reverse. MinLevel = I; break; } @@ -206,7 +223,7 @@ namespace bolt { if (MaxLevel == MinLevel || MaxLevelLeft == 0 || MaxLevelRight == 0) { solveCEqual(Y); } else { - Contexts[MaxLevel]->Constraints->push_back(C); + Contexts[Contexts.size() - MaxLevel - 1]->Constraints->push_back(C); } break; @@ -341,7 +358,6 @@ namespace bolt { std::vector Vars; for (auto TE: Decl->Vars) { auto TV = createRigidVar(TE->Name->getCanonicalText()); - Decl->Ctx->TVs->emplace(TV); Vars.push_back(TV); } @@ -517,10 +533,10 @@ namespace bolt { } } - popContext(); Type* BindTy; if (Let->isFunc()) { + popContext(); BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs); } else { BindTy = inferPattern(Let->Pattern); @@ -566,9 +582,8 @@ namespace bolt { RetType = createTypeVar(); } - popContext(); - if (Decl->isFunc()) { + popContext(); addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl }); } else { // Declaration is a plain (typed) variable @@ -698,12 +713,20 @@ namespace bolt { for (auto Constraint: *F->Constraints) { + // FIXME improve this + if (Constraint->getKind() == ConstraintKind::Equal) { + auto Eq = static_cast(Constraint); + Eq->Left = simplifyType(Eq->Left); + Eq->Right = simplifyType(Eq->Right); + } + auto NewConstraint = Constraint->substitute(Sub); // This makes error messages prettier by relating the typing failure // to the call site rather than the definition. if (NewConstraint->getKind() == ConstraintKind::Equal) { - static_cast(NewConstraint)->Source = Source; + auto Eq = static_cast(Constraint); + Eq->Source = Source; } addConstraint(NewConstraint); @@ -868,7 +891,7 @@ namespace bolt { for (auto Case: Match->Cases) { pushContext(Case->Ctx); auto PattTy = inferPattern(Case->Pattern); - addConstraint(new CEqual(PattTy, ValTy, X)); + addConstraint(new CEqual(PattTy, ValTy, Case)); auto ExprTy = inferExpression(Case->Expression); addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); popContext(); diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index e362b7a87..e7dfeb9e8 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -118,6 +118,12 @@ namespace bolt { return "a function or variable reference"; case NodeKind::MatchExpression: return "a match-expression"; + case NodeKind::ConstantExpression: + return "a literal expression"; + case NodeKind::IfStatement: + return "an if-statement"; + case NodeKind::IfStatementPart: + return "a branch of an if-statement"; default: ZEN_UNREACHABLE }