Eagerly solve some constraints when certain conditions are met

This commit is contained in:
Sam Vervaeck 2023-05-23 20:09:05 +02:00
parent acbfeb8975
commit b8e989d03f
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY

View file

@ -127,32 +127,34 @@ namespace bolt {
{ {
auto Y = static_cast<CEqual*>(C); auto Y = static_cast<CEqual*>(C);
// std::size_t MaxLevel = 0; // FIXME this logic breaks id x
// for (std::size_t I = Contexts.size(); I-- > 0; ) {
// auto Ctx = Contexts[I];
// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
// MaxLevel = I;
// break;
// }
// }
// std::size_t MinLevel = MaxLevel; std::size_t MaxLevel = 0;
// for (std::size_t I = 0; I < Contexts.size(); I++) { for (std::size_t I = Contexts.size(); I-- > 0; ) {
// auto Ctx = Contexts[I]; auto Ctx = Contexts[I];
// if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
// MinLevel = I; MaxLevel = I;
// break; break;
// } }
// } }
// // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0 std::size_t MinLevel = MaxLevel;
// if (MaxLevel == MinLevel) { for (std::size_t I = 0; I < Contexts.size(); I++) {
// solveCEqual(Y); auto Ctx = Contexts[I];
// } else { if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) {
// Contexts[MaxLevel]->Constraints->push_back(C); MinLevel = I;
// } break;
}
}
Contexts.back()->Constraints->push_back(C); // TODO detect if MaxLevelLeft == 0 or MaxLevelRight == 0
if (MaxLevel == MinLevel) {
solveCEqual(Y);
} else {
Contexts[MaxLevel]->Constraints->push_back(C);
}
// Contexts.back()->Constraints->push_back(C);
break; break;
} }
case ConstraintKind::Many: case ConstraintKind::Many:
@ -514,13 +516,18 @@ namespace bolt {
// This makes error messages prettier by relating the typing failure // This makes error messages prettier by relating the typing failure
// to the call site rather than the definition. // to the call site rather than the definition.
if (NewConstraint->getKind() == ConstraintKind::Equal) { if (NewConstraint->getKind() == ConstraintKind::Equal) {
static_cast<CEqual *>(NewConstraint)->Source = Source; static_cast<CEqual*>(NewConstraint)->Source = Source;
} }
addConstraint(NewConstraint); addConstraint(NewConstraint);
} }
return F->Type->substitute(Sub); // Note the call to solve? This is because constraints may have already
// been solved, with some unification variables being erased. To make
// sure we instantiate unification variables that are still in use
// we solve before substituting.
// TODO perform a full solve()
return find(F->Type)->substitute(Sub);
} }
} }
@ -613,6 +620,8 @@ namespace bolt {
Type* Checker::inferExpression(Expression* X) { Type* Checker::inferExpression(Expression* X) {
Type* Ty;
switch (X->getKind()) { switch (X->getKind()) {
case NodeKind::MatchExpression: case NodeKind::MatchExpression:
@ -624,27 +633,26 @@ namespace bolt {
} else { } else {
ValTy = createTypeVar(); ValTy = createTypeVar();
} }
auto ResTy = createTypeVar(); Ty = createTypeVar();
for (auto Case: Match->Cases) { for (auto Case: Match->Cases) {
auto NewCtx = createInferContext(); auto NewCtx = createInferContext();
Contexts.push_back(NewCtx); Contexts.push_back(NewCtx);
inferBindings(Case->Pattern, ValTy); inferBindings(Case->Pattern, ValTy);
auto Ty = inferExpression(Case->Expression); auto Ty = inferExpression(Case->Expression);
addConstraint(new CEqual(Ty, ResTy, Case->Expression)); addConstraint(new CEqual(Ty, Ty, Case->Expression));
Contexts.pop_back(); Contexts.pop_back();
} }
if (!Match->Value) { if (!Match->Value) {
return new TArrow({ ValTy }, ResTy); Ty = new TArrow({ ValTy }, Ty);
} }
return ResTy; break;
} }
case NodeKind::ConstantExpression: case NodeKind::ConstantExpression:
{ {
auto Const = static_cast<ConstantExpression*>(X); auto Const = static_cast<ConstantExpression*>(X);
auto Ty = inferLiteral(Const->Token); Ty = inferLiteral(Const->Token);
X->setType(Ty); break;
return Ty;
} }
case NodeKind::ReferenceExpression: case NodeKind::ReferenceExpression:
@ -662,23 +670,21 @@ namespace bolt {
DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name); DE.add<BindingNotFoundDiagnostic>(Ref->Name->getCanonicalText(), Ref->Name);
return createTypeVar(); return createTypeVar();
} }
auto Ty = instantiate(Scm, X); Ty = instantiate(Scm, X);
X->setType(Ty); break;
return Ty;
} }
case NodeKind::CallExpression: case NodeKind::CallExpression:
{ {
auto Call = static_cast<CallExpression*>(X); auto Call = static_cast<CallExpression*>(X);
auto OpTy = inferExpression(Call->Function); auto OpTy = inferExpression(Call->Function);
auto RetType = createTypeVar(); Ty = createTypeVar();
std::vector<Type*> ArgTypes; std::vector<Type*> ArgTypes;
for (auto Arg: Call->Args) { for (auto Arg: Call->Args) {
ArgTypes.push_back(inferExpression(Arg)); ArgTypes.push_back(inferExpression(Arg));
} }
addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, Ty), X });
X->setType(RetType); break;
return RetType;
} }
case NodeKind::InfixExpression: case NodeKind::InfixExpression:
@ -690,13 +696,12 @@ namespace bolt {
return createTypeVar(); return createTypeVar();
} }
auto OpTy = instantiate(Scm, Infix->Operator); auto OpTy = instantiate(Scm, Infix->Operator);
auto RetTy = createTypeVar(); auto Ty = createTypeVar();
std::vector<Type*> ArgTys; std::vector<Type*> ArgTys;
ArgTys.push_back(inferExpression(Infix->LHS)); ArgTys.push_back(inferExpression(Infix->LHS));
ArgTys.push_back(inferExpression(Infix->RHS)); ArgTys.push_back(inferExpression(Infix->RHS));
addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); addConstraint(new CEqual { new TArrow(ArgTys, Ty), OpTy, X });
X->setType(RetTy); break;
return RetTy;
} }
case NodeKind::TupleExpression: case NodeKind::TupleExpression:
@ -706,7 +711,8 @@ namespace bolt {
for (auto [E, Comma]: Tuple->Elements) { for (auto [E, Comma]: Tuple->Elements) {
Types.push_back(inferExpression(E)); Types.push_back(inferExpression(E));
} }
return new TTuple(Types); Ty = new TTuple(Types);
break;
} }
case NodeKind::MemberExpression: case NodeKind::MemberExpression:
@ -716,7 +722,8 @@ namespace bolt {
case NodeKind::IntegerLiteral: case NodeKind::IntegerLiteral:
{ {
auto I = static_cast<IntegerLiteral*>(Member->Name); auto I = static_cast<IntegerLiteral*>(Member->Name);
return new TTupleIndex(inferExpression(Member->E), I->getInteger()); Ty = new TTupleIndex(inferExpression(Member->E), I->getInteger());
break;
} }
case NodeKind::Identifier: case NodeKind::Identifier:
{ {
@ -731,7 +738,8 @@ namespace bolt {
case NodeKind::NestedExpression: case NodeKind::NestedExpression:
{ {
auto Nested = static_cast<NestedExpression*>(X); auto Nested = static_cast<NestedExpression*>(X);
return inferExpression(Nested->Inner); Ty = inferExpression(Nested->Inner);
break;
} }
default: default:
@ -739,6 +747,9 @@ namespace bolt {
} }
// Ty = find(Ty);
X->setType(Ty);
return Ty;
} }
void Checker::inferBindings( void Checker::inferBindings(
@ -1157,13 +1168,17 @@ namespace bolt {
// called, it may decide to solve the constraint immediately during // called, it may decide to solve the constraint immediately during
// inference. If this happens, a type variable might get assigned a concrete // inference. If this happens, a type variable might get assigned a concrete
// type such as Int. We therefore never want the variable to be polymorphic // type such as Int. We therefore never want the variable to be polymorphic
// and be instantiated with a fresh variable, as it has already been solved. // and be instantiated with a fresh variable, as that would allow Bool to
// collide with Int.
//
// Should it get assigned another unification variable, that's OK too // Should it get assigned another unification variable, that's OK too
// because then the inference context of that variable is what matters and // because then that variable is what matters and it will become the new
// not anymore the context of this one. // (possibly polymorphic) variable.
// if (!Contexts.empty()) { if (!Contexts.empty()) {
// Contexts.back()->TVs->erase(TV); // std::cerr << "erase " << describe(TV) << std::endl;
// } auto TVs = Contexts.back()->TVs;
TVs->erase(TV);
}
} }