diff --git a/src/Checker.cc b/src/Checker.cc index 231573889..35c17a584 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -127,32 +127,34 @@ namespace bolt { { auto Y = static_cast(C); -// std::size_t MaxLevel = 0; -// for (std::size_t I = Contexts.size(); I-- > 0; ) { -// auto Ctx = Contexts[I]; -// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { -// MaxLevel = I; -// break; -// } -// } + // FIXME this logic breaks id x -// std::size_t MinLevel = MaxLevel; -// for (std::size_t I = 0; I < Contexts.size(); I++) { -// auto Ctx = Contexts[I]; -// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { -// MinLevel = I; -// break; -// } -// } + std::size_t MaxLevel = 0; + for (std::size_t I = Contexts.size(); I-- > 0; ) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { + MaxLevel = I; + break; + } + } -// // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0 -// if (MaxLevel == MinLevel) { -// solveCEqual(Y); -// } else { -// Contexts[MaxLevel]->Constraints->push_back(C); -// } + std::size_t MinLevel = MaxLevel; + for (std::size_t I = 0; I < Contexts.size(); I++) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { + MinLevel = I; + break; + } + } - Contexts.back()->Constraints->push_back(C); + // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0 + if (MaxLevel == MinLevel) { + solveCEqual(Y); + } else { + Contexts[MaxLevel]->Constraints->push_back(C); + } + + // Contexts.back()->Constraints->push_back(C); break; } case ConstraintKind::Many: @@ -514,13 +516,18 @@ namespace bolt { // This makes error messages prettier by relating the typing failure // to the call site rather than the definition. if (NewConstraint->getKind() == ConstraintKind::Equal) { - static_cast(NewConstraint)->Source = Source; + static_cast(NewConstraint)->Source = Source; } addConstraint(NewConstraint); } - return F->Type->substitute(Sub); + // Note the call to solve? This is because constraints may have already + // been solved, with some unification variables being erased. To make + // sure we instantiate unification variables that are still in use + // we solve before substituting. + // TODO perform a full solve() + return find(F->Type)->substitute(Sub); } } @@ -613,6 +620,8 @@ namespace bolt { Type* Checker::inferExpression(Expression* X) { + Type* Ty; + switch (X->getKind()) { case NodeKind::MatchExpression: @@ -624,27 +633,26 @@ namespace bolt { } else { ValTy = createTypeVar(); } - auto ResTy = createTypeVar(); + Ty = createTypeVar(); for (auto Case: Match->Cases) { auto NewCtx = createInferContext(); Contexts.push_back(NewCtx); inferBindings(Case->Pattern, ValTy); auto Ty = inferExpression(Case->Expression); - addConstraint(new CEqual(Ty, ResTy, Case->Expression)); + addConstraint(new CEqual(Ty, Ty, Case->Expression)); Contexts.pop_back(); } if (!Match->Value) { - return new TArrow({ ValTy }, ResTy); + Ty = new TArrow({ ValTy }, Ty); } - return ResTy; + break; } case NodeKind::ConstantExpression: { auto Const = static_cast(X); - auto Ty = inferLiteral(Const->Token); - X->setType(Ty); - return Ty; + Ty = inferLiteral(Const->Token); + break; } case NodeKind::ReferenceExpression: @@ -662,23 +670,21 @@ namespace bolt { DE.add(Ref->Name->getCanonicalText(), Ref->Name); return createTypeVar(); } - auto Ty = instantiate(Scm, X); - X->setType(Ty); - return Ty; + Ty = instantiate(Scm, X); + break; } case NodeKind::CallExpression: { auto Call = static_cast(X); auto OpTy = inferExpression(Call->Function); - auto RetType = createTypeVar(); + Ty = createTypeVar(); std::vector ArgTypes; for (auto Arg: Call->Args) { ArgTypes.push_back(inferExpression(Arg)); } - addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); - X->setType(RetType); - return RetType; + addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, Ty), X }); + break; } case NodeKind::InfixExpression: @@ -690,13 +696,12 @@ namespace bolt { return createTypeVar(); } auto OpTy = instantiate(Scm, Infix->Operator); - auto RetTy = createTypeVar(); + auto Ty = createTypeVar(); std::vector ArgTys; ArgTys.push_back(inferExpression(Infix->LHS)); ArgTys.push_back(inferExpression(Infix->RHS)); - addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); - X->setType(RetTy); - return RetTy; + addConstraint(new CEqual { new TArrow(ArgTys, Ty), OpTy, X }); + break; } case NodeKind::TupleExpression: @@ -706,7 +711,8 @@ namespace bolt { for (auto [E, Comma]: Tuple->Elements) { Types.push_back(inferExpression(E)); } - return new TTuple(Types); + Ty = new TTuple(Types); + break; } case NodeKind::MemberExpression: @@ -716,7 +722,8 @@ namespace bolt { case NodeKind::IntegerLiteral: { auto I = static_cast(Member->Name); - return new TTupleIndex(inferExpression(Member->E), I->getInteger()); + Ty = new TTupleIndex(inferExpression(Member->E), I->getInteger()); + break; } case NodeKind::Identifier: { @@ -731,7 +738,8 @@ namespace bolt { case NodeKind::NestedExpression: { auto Nested = static_cast(X); - return inferExpression(Nested->Inner); + Ty = inferExpression(Nested->Inner); + break; } default: @@ -739,6 +747,9 @@ namespace bolt { } + // Ty = find(Ty); + X->setType(Ty); + return Ty; } void Checker::inferBindings( @@ -1157,13 +1168,17 @@ namespace bolt { // called, it may decide to solve the constraint immediately during // inference. If this happens, a type variable might get assigned a concrete // type such as Int. We therefore never want the variable to be polymorphic - // and be instantiated with a fresh variable, as it has already been solved. + // and be instantiated with a fresh variable, as that would allow Bool to + // collide with Int. + // // Should it get assigned another unification variable, that's OK too - // because then the inference context of that variable is what matters and - // not anymore the context of this one. - // if (!Contexts.empty()) { - // Contexts.back()->TVs->erase(TV); - // } + // because then that variable is what matters and it will become the new + // (possibly polymorphic) variable. + if (!Contexts.empty()) { + // std::cerr << "erase " << describe(TV) << std::endl; + auto TVs = Contexts.back()->TVs; + TVs->erase(TV); + } }