bolt/src/Checker.cc

504 lines
13 KiB
C++
Raw Normal View History

2022-08-21 16:25:52 +02:00
#include <stack>
2022-08-21 20:56:58 +02:00
#include "bolt/Diagnostics.hpp"
2022-08-21 16:25:52 +02:00
#include "zen/config.hpp"
#include "bolt/CST.hpp"
#include "bolt/Checker.hpp"
namespace bolt {
std::string describe(const Type* Ty);
2022-08-21 20:56:58 +02:00
2022-08-21 16:25:52 +02:00
bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) {
case TypeKind::Var:
return static_cast<TVar*>(this)->Id == TV->Id;
case TypeKind::Arrow:
{
auto Y = static_cast<TArrow*>(this);
for (auto Ty: Y->ParamTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return Y->ReturnType->hasTypeVar(TV);
}
2022-08-21 20:56:58 +02:00
case TypeKind::Con:
{
auto Y = static_cast<TCon*>(this);
for (auto Ty: Y->Args) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::Any:
return false;
2022-08-21 16:25:52 +02:00
}
}
// Applies the substitution Sub to this type and returns the resulting type.
//
// - Var: if Sub maps this variable, the mapped type is itself substituted
//   and returned; otherwise the variable is returned unchanged.
// - Any: returned unchanged.
// - Arrow / Con: a freshly allocated node is returned whose children are the
//   substituted children of this node.
Type* Type::substitute(const TVSub &Sub) {
  switch (Kind) {
    case TypeKind::Any:
      return this;
    case TypeKind::Var:
      {
        auto Self = static_cast<TVar*>(this);
        auto Entry = Sub.find(Self);
        if (Entry == Sub.end()) {
          return Self;
        }
        // The mapped type may itself mention substituted variables.
        return Entry->second->substitute(Sub);
      }
    case TypeKind::Con:
      {
        auto Self = static_cast<TCon*>(this);
        std::vector<Type*> SubstArgs;
        for (std::size_t I = 0; I < Self->Args.size(); I++) {
          SubstArgs.push_back(Self->Args[I]->substitute(Sub));
        }
        return new TCon(Self->Id, SubstArgs, Self->DisplayName);
      }
    case TypeKind::Arrow:
      {
        auto Self = static_cast<TArrow*>(this);
        std::vector<Type*> SubstParams;
        for (std::size_t I = 0; I < Self->ParamTypes.size(); I++) {
          SubstParams.push_back(Self->ParamTypes[I]->substitute(Sub));
        }
        auto SubstReturn = Self->ReturnType->substitute(Sub);
        return new TArrow(SubstParams, SubstReturn);
      }
  }
}
// Applies the substitution Sub to every type mentioned by this constraint.
//
// - Empty: returned unchanged.
// - Equal: a new CEqual over the substituted left and right types, keeping
//   the original Source node.
// - Many: a new CMany over a newly allocated set holding the substituted
//   element constraints, in the same order.
Constraint* Constraint::substitute(const TVSub &Sub) {
  switch (Kind) {
    case ConstraintKind::Empty:
      return this;
    case ConstraintKind::Equal:
      {
        auto Eq = static_cast<CEqual*>(this);
        auto NewLeft = Eq->Left->substitute(Sub);
        auto NewRight = Eq->Right->substitute(Sub);
        return new CEqual(NewLeft, NewRight, Eq->Source);
      }
    case ConstraintKind::Many:
      {
        auto Group = static_cast<CMany*>(this);
        // The set is allocated on the heap and handed to the new CMany by
        // dereference, exactly as before.
        auto Substituted = new ConstraintSet();
        for (auto Member: Group->Constraints) {
          Substituted->push_back(Member->substitute(Sub));
        }
        return new CMany(*Substituted);
      }
  }
}
// Finds the scheme bound to Name, searching this context first and then
// each Parent in turn.
//
// Returns a pointer to the scheme stored in the first context that has a
// binding for Name, or nullptr when the chain of parents is exhausted.
Scheme* InferContext::lookup(ByteString Name) {
  for (InferContext* Scope = this; Scope != nullptr; Scope = Scope->Parent) {
    auto Found = Scope->Env.find(Name);
    if (Found != Scope->Env.end()) {
      return &Found->second;
    }
  }
  return nullptr;
}
// Looks up Name and returns the type of its scheme, which is expected to be
// monomorphic.
//
// Returns nullptr when Name is not bound. Otherwise the scheme is read as a
// Forall and it is asserted that it quantifies over no type variables.
Type* InferContext::lookupMono(ByteString Name) {
  Scheme* Found = lookup(Name);
  if (Found != nullptr) {
    auto& F = Found->as<Forall>();
    ZEN_ASSERT(F.TVs == nullptr || F.TVs->empty());
    return F.Type;
  }
  return nullptr;
}
// Records the scheme S under Name in this context's own environment.
//
// NOTE(review): Env is inserted into with emplace; if Env is a standard
// associative map, an already existing binding for Name is left in place
// rather than replaced -- confirm this is intended.
void InferContext::addBinding(ByteString Name, Scheme S) {
  this->Env.emplace(Name, S);
}
2022-08-21 16:25:52 +02:00
// Appends the constraint C to the end of this context's constraint list.
void InferContext::addConstraint(Constraint *C) {
  this->Constraints.push_back(C);
}
2022-08-21 20:56:58 +02:00
// Creates a checker that reports its diagnostics to DE.
Checker::Checker(DiagnosticEngine& DE): DE(DE) {
}
2022-08-21 16:25:52 +02:00
// Infers types for a statement-level node and records the resulting
// bindings and constraints.
//
// - SourceFile: every element is inferred in order, in Ctx.
// - LetDeclaration: the declaration is inferred in a newly allocated context
//   created from Ctx; afterwards the declared pattern is bound in Ctx with a
//   scheme built from that new context's type variables and constraints.
// - ExpressionStatement: the expression is inferred and its type discarded.
//
// Any other node type reaches ZEN_UNREACHABLE.
void Checker::infer(Node* X, InferContext& Ctx) {
  switch (X->Type) {
    case NodeType::SourceFile:
      {
        auto Y = static_cast<SourceFile*>(X);
        for (auto Element: Y->Elements) {
          infer(Element, Ctx);
        }
        break;
      }
    case NodeType::LetDeclaration:
      {
        auto Y = static_cast<LetDeclaration*>(X);
        // Heap-allocated on purpose: the scheme registered at the end of this
        // case is built from NewCtx->Constraints and NewCtx->TVs, so the
        // context has to outlive this call.
        auto NewCtx = new InferContext { Ctx };
        Type* Ty;
        if (Y->TypeAssert) {
          Ty = inferTypeExpression(Y->TypeAssert->TypeExpression, *NewCtx);
        } else {
          Ty = createTypeVar(*NewCtx);
        }
        std::vector<Type*> ParamTypes;
        Type* RetType = nullptr;
        for (auto Param: Y->Params) {
          // TODO incorporate Param->TypeAssert or make it a kind of pattern
          TVar* TV = createTypeVar(*NewCtx);
          // Parameters are bound monomorphically: no quantified variables and
          // no constraints of their own. inferBindings receives these sets by
          // reference and builds a Forall from them, and Forall exposes its
          // TVs/Constraints as pointers (see lookupMono and instantiate), so
          // the sets must stay alive for as long as the binding can be looked
          // up. Loop-local stack objects would be destroyed at the end of the
          // iteration and leave the scheme dangling, hence the heap
          // allocation (same lifetime strategy as NewCtx above).
          // NOTE(review): confirm against Forall's constructor that it stores
          // the addresses of its arguments.
          auto NoTVs = new TVSet;
          auto NoConstraints = new ConstraintSet;
          inferBindings(Param->Pattern, TV, *NewCtx, *NoConstraints, *NoTVs);
          ParamTypes.push_back(TV);
        }
        if (Y->Body) {
          switch (Y->Body->Type) {
            case NodeType::LetExprBody:
              {
                auto Z = static_cast<LetExprBody*>(Y->Body);
                RetType = inferExpression(Z->Expression, *NewCtx);
                break;
              }
            case NodeType::LetBlockBody:
              {
                auto Z = static_cast<LetBlockBody*>(Y->Body);
                // A block body gets a fresh return type variable; its
                // elements are inferred inside the declaration's context.
                RetType = createTypeVar(*NewCtx);
                for (auto Element: Z->Elements) {
                  infer(Element, *NewCtx);
                }
                break;
              }
            default:
              ZEN_UNREACHABLE
          }
        } else {
          // No body: the return type stays an unconstrained fresh variable.
          RetType = createTypeVar(*NewCtx);
        }
        // The declared type must equal the arrow built from the parameter
        // types and the return type.
        NewCtx->addConstraint(new CEqual { Ty, new TArrow(ParamTypes, RetType), X });
        inferBindings(Y->Pattern, Ty, Ctx, NewCtx->Constraints, NewCtx->TVs);
        break;
      }
    case NodeType::ExpressionStatement:
      {
        auto Y = static_cast<ExpressionStatement*>(X);
        inferExpression(Y->Expression, Ctx);
        break;
      }
    default:
      ZEN_UNREACHABLE
  }
}
// Allocates a type variable with the next unused id, registers it in
// Ctx.TVs, and returns it.
TVar* Checker::createTypeVar(InferContext& Ctx) {
  TVar* Fresh = new TVar(nextTypeVarId++);
  Ctx.TVs.emplace(Fresh);
  return Fresh;
}
// Instantiates the scheme S in Ctx and returns the resulting type.
//
// Each type variable quantified by the scheme is mapped to a fresh type
// variable created in Ctx. The scheme's constraints are copied under that
// mapping and added to Ctx, and the scheme's type is returned under the same
// mapping. Source is the node the instantiation is attributed to.
Type* Checker::instantiate(Scheme& S, InferContext& Ctx, Node* Source) {
  switch (S.getKind()) {
    case SchemeKind::Forall:
      {
        auto& Poly = S.as<Forall>();
        TVSub Fresh;
        if (Poly.TVs) {
          for (auto Bound: *Poly.TVs) {
            Fresh[Bound] = createTypeVar(Ctx);
          }
        }
        if (Poly.Constraints) {
          for (auto Original: *Poly.Constraints) {
            auto Copy = Original->substitute(Fresh);
            // This makes error messages prettier by relating the typing
            // failure to the call site rather than the definition.
            if (Copy->getKind() == ConstraintKind::Equal) {
              auto Eq = static_cast<CEqual *>(Copy);
              Eq->Source = Source;
            }
            Ctx.addConstraint(Copy);
          }
        }
        return Poly.Type->substitute(Fresh);
      }
  }
}
// Converts a type expression from the syntax tree into a Type.
//
// - ReferenceTypeExpression: the name is resolved with Ctx.lookupMono. When
//   it is not bound, a BindingNotFoundDiagnostic is reported and a new TAny
//   is returned.
// - ArrowTypeExpression: a new TArrow is built from the converted parameter
//   types and the converted return type.
//
// Any other node type reaches ZEN_UNREACHABLE.
Type* Checker::inferTypeExpression(TypeExpression* X, InferContext& Ctx) {
  switch (X->Type) {
    case NodeType::ArrowTypeExpression:
      {
        auto Arrow = static_cast<ArrowTypeExpression*>(X);
        std::vector<Type*> Params;
        for (auto ParamExpr: Arrow->ParamTypes) {
          auto ParamTy = inferTypeExpression(ParamExpr, Ctx);
          Params.push_back(ParamTy);
        }
        auto Result = inferTypeExpression(Arrow->ReturnType, Ctx);
        return new TArrow(Params, Result);
      }
    case NodeType::ReferenceTypeExpression:
      {
        auto Ref = static_cast<ReferenceTypeExpression*>(X);
        auto Resolved = Ctx.lookupMono(Ref->Name->Name->Text);
        if (Resolved != nullptr) {
          return Resolved;
        }
        // Unknown type name: report it and carry on with a TAny.
        DE.add<BindingNotFoundDiagnostic>(Ref->Name->Name->Text, Ref->Name->Name);
        return new TAny();
      }
    default:
      ZEN_UNREACHABLE
  }
}
2022-08-21 16:25:52 +02:00
// Infers the type of an expression, adding any constraints it gives rise to
// to Ctx, and returns that type.
//
// - ConstantExpression: integer literals get the type bound to "Int", string
//   literals the type bound to "String".
// - ReferenceExpression: the referenced scheme is instantiated; an unbound
//   name is reported and typed as a new TAny.
// - CallExpression: a fresh return type variable is created and the callee
//   type is constrained to equal the arrow from the argument types to it.
// - InfixExpression: the operator is looked up by its text and instantiated,
//   then constrained like a two-argument call; an unbound operator is
//   reported and typed as a new TAny.
//
// Any other node type reaches ZEN_UNREACHABLE.
Type* Checker::inferExpression(Expression* X, InferContext& Ctx) {
  switch (X->Type) {
    case NodeType::ConstantExpression:
      {
        auto Y = static_cast<ConstantExpression*>(X);
        Type* Ty = nullptr;
        const auto LiteralKind = Y->Token->Type;
        if (LiteralKind == NodeType::IntegerLiteral) {
          Ty = Ctx.lookupMono("Int");
        } else if (LiteralKind == NodeType::StringLiteral) {
          Ty = Ctx.lookupMono("String");
        } else {
          ZEN_UNREACHABLE
        }
        ZEN_ASSERT(Ty != nullptr);
        return Ty;
      }
    case NodeType::ReferenceExpression:
      {
        auto Y = static_cast<ReferenceExpression*>(X);
        ZEN_ASSERT(Y->Name->ModulePath.empty());
        auto Found = Ctx.lookup(Y->Name->Name->Text);
        if (Found != nullptr) {
          return instantiate(*Found, Ctx, X);
        }
        DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
        return new TAny();
      }
    case NodeType::CallExpression:
      {
        auto Call = static_cast<CallExpression*>(X);
        // Callee first, then the result variable, then the arguments: this
        // order fixes the ids of the fresh type variables.
        auto CalleeTy = inferExpression(Call->Function, Ctx);
        auto ResultTy = createTypeVar(Ctx);
        std::vector<Type*> ActualTys;
        for (auto Actual: Call->Args) {
          auto ActualTy = inferExpression(Actual, Ctx);
          ActualTys.push_back(ActualTy);
        }
        Ctx.addConstraint(new CEqual { CalleeTy, new TArrow(ActualTys, ResultTy), X });
        return ResultTy;
      }
    case NodeType::InfixExpression:
      {
        auto Infix = static_cast<InfixExpression*>(X);
        auto Found = Ctx.lookup(Infix->Operator->getText());
        if (Found == nullptr) {
          DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
          return new TAny();
        }
        auto OperatorTy = instantiate(*Found, Ctx, Infix->Operator);
        auto ResultTy = createTypeVar(Ctx);
        auto LeftTy = inferExpression(Infix->LHS, Ctx);
        auto RightTy = inferExpression(Infix->RHS, Ctx);
        std::vector<Type*> OperandTys;
        OperandTys.push_back(LeftTy);
        OperandTys.push_back(RightTy);
        Ctx.addConstraint(new CEqual { new TArrow(OperandTys, ResultTy), OperatorTy, X });
        return ResultTy;
      }
    default:
      ZEN_UNREACHABLE
  }
}
// Binds the names introduced by Pattern in Ctx.
//
// Only BindPattern is handled: its name is bound to a Forall scheme built
// from TVs, Constraints and Type. Any other pattern kind reaches
// ZEN_UNREACHABLE.
void Checker::inferBindings(Pattern* Pattern, Type* Type, InferContext& Ctx, ConstraintSet& Constraints, TVSet& TVs) {
  switch (Pattern->Type) {
    case NodeType::BindPattern:
      {
        auto Bind = static_cast<BindPattern*>(Pattern);
        Ctx.addBinding(Bind->Name->Text, Forall(TVs, Constraints, Type));
        break;
      }
    default:
      ZEN_UNREACHABLE
  }
}
2022-08-21 16:25:52 +02:00
// Type-checks a whole source file.
//
// A top-level context is populated with the builtin types String, Int and
// Bool and with the operators +, -, * and /, each typed as taking two Int
// values and returning an Int. The file is then inferred in that context and
// the constraints collected in it are solved.
void Checker::check(SourceFile *SF) {
  InferContext Toplevel;

  // Builtin type constructors; ids are handed out in this order.
  TCon* StringTy = new TCon(nextConTypeId++, {}, "String");
  TCon* IntTy = new TCon(nextConTypeId++, {}, "Int");
  TCon* BoolTy = new TCon(nextConTypeId++, {}, "Bool");
  Toplevel.addBinding("String", Forall(StringTy));
  Toplevel.addBinding("Int", Forall(IntTy));
  Toplevel.addBinding("Bool", Forall(BoolTy));

  // Every arithmetic operator gets its own arrow object.
  auto MakeIntBinOp = [IntTy]() {
    return new TArrow({ IntTy, IntTy }, IntTy);
  };
  Toplevel.addBinding("+", Forall(MakeIntBinOp()));
  Toplevel.addBinding("-", Forall(MakeIntBinOp()));
  Toplevel.addBinding("*", Forall(MakeIntBinOp()));
  Toplevel.addBinding("/", Forall(MakeIntBinOp()));

  infer(SF, Toplevel);
  solve(new CMany(Toplevel.Constraints));
}
// Solves a constraint, reporting a UnificationErrorDiagnostic for every
// equality that cannot be unified.
//
// Constraints are processed from a stack: a Many constraint pushes its
// members, an Empty constraint is dropped, and an Equal constraint is
// written to std::cerr and then unified against the solution built so far.
void Checker::solve(Constraint* Constraint) {
  std::stack<class Constraint*> Pending;
  Pending.push(Constraint);

  TVSub Solution;

  while (!Pending.empty()) {
    auto Current = Pending.top();
    Pending.pop();

    switch (Current->getKind()) {
      case ConstraintKind::Empty:
        break;
      case ConstraintKind::Many:
        {
          auto Group = static_cast<CMany*>(Current);
          for (auto Member: Group->Constraints) {
            Pending.push(Member);
          }
          break;
        }
      case ConstraintKind::Equal:
        {
          auto Eq = static_cast<CEqual*>(Current);
          std::cerr << describe(Eq->Left) << " ~ " << describe(Eq->Right) << std::endl;
          const bool Unified = unify(Eq->Left, Eq->Right, Solution);
          if (!Unified) {
            // Show both sides with everything solved so far filled in.
            DE.add<UnificationErrorDiagnostic>(Eq->Left->substitute(Solution), Eq->Right->substitute(Solution), Eq->Source);
          }
          break;
        }
    }
  }
}
// Tries to make the types A and B equal by extending Solution.
//
// Returns true when the two types unify, possibly after binding type
// variables in Solution, and false when they cannot be made equal. Bindings
// added before a failure is detected are left in Solution.
//
// Rules, in order:
// - Both sides are first resolved through Solution until they are either a
//   non-variable type or an unbound variable.
// - An unbound variable unifies with itself, and with any type that does not
//   contain it (the variable is then bound to that type).
// - Any unifies with everything.
// - Two arrows unify when they have the same number of parameters and their
//   parameters and return types unify pairwise.
// - Two constructor types unify when they have the same Id and their
//   arguments unify pairwise.
// - Everything else fails.
bool Checker::unify(Type* A, Type* B, TVSub& Solution) {

  // Follows variable bindings all the way. Stopping after a single step would
  // let a chain such as a -> b -> Int resolve `a` to the *bound* variable
  // `b`, which the code below would then rebind, silently discarding b's
  // existing binding instead of unifying against it.
  auto Resolve = [&Solution](Type* Ty) {
    while (Ty->getKind() == TypeKind::Var) {
      auto Match = Solution.find(static_cast<TVar*>(Ty));
      if (Match == Solution.end()) {
        break;
      }
      Ty = Match->second;
    }
    return Ty;
  };

  A = Resolve(A);
  B = Resolve(B);

  if (A->getKind() == TypeKind::Var) {
    auto Y = static_cast<TVar*>(A);
    // A variable is trivially equal to itself. Without this check the occurs
    // check below would reject `a ~ a`, because `a` does occur in `a`.
    if (B->getKind() == TypeKind::Var && static_cast<TVar*>(B)->Id == Y->Id) {
      return true;
    }
    // Occurs check: binding Y to a type that mentions Y would describe an
    // infinite type.
    // NOTE(review): B's sub-terms are not resolved through Solution here, so
    // an occurrence that is only reachable through an existing binding is
    // not detected.
    if (B->hasTypeVar(Y)) {
      return false;
    }
    Solution[Y] = B;
    return true;
  }

  if (B->getKind() == TypeKind::Var) {
    // Handle the symmetric case with the variable on the left.
    return unify(B, A, Solution);
  }

  if (A->getKind() == TypeKind::Any || B->getKind() == TypeKind::Any) {
    return true;
  }

  if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) {
    auto Y = static_cast<TArrow*>(A);
    auto Z = static_cast<TArrow*>(B);
    if (Y->ParamTypes.size() != Z->ParamTypes.size()) {
      return false;
    }
    auto Count = Y->ParamTypes.size();
    for (std::size_t I = 0; I < Count; I++) {
      if (!unify(Y->ParamTypes[I], Z->ParamTypes[I], Solution)) {
        return false;
      }
    }
    return unify(Y->ReturnType, Z->ReturnType, Solution);
  }

  if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) {
    auto Y = static_cast<TCon*>(A);
    auto Z = static_cast<TCon*>(B);
    if (Y->Id != Z->Id) {
      return false;
    }
    // The same constructor is expected to always carry the same number of
    // arguments.
    ZEN_ASSERT(Y->Args.size() == Z->Args.size());
    auto Count = Y->Args.size();
    for (std::size_t I = 0; I < Count; I++) {
      if (!unify(Y->Args[I], Z->Args[I], Solution)) {
        return false;
      }
    }
    return true;
  }

  return false;
}
}