Type-check match-expressions, nested expressions and literal patterns

Also introduces '$' as a new binding
This commit is contained in:
Sam Vervaeck 2024-07-11 10:09:28 +02:00
parent d4af7f5059
commit 4b9fbc1d0c
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
2 changed files with 93 additions and 17 deletions

View file

@ -117,7 +117,7 @@ public:
Type* instantiate(TypeScheme* Scm); Type* instantiate(TypeScheme* Scm);
void visitPattern(Pattern* P, Type* Ty, TypeEnv& Out); ConstraintSet visitPattern(Pattern* P, Type* Ty, TypeEnv& Out);
ConstraintSet inferSourceFile(TypeEnv& Env, SourceFile* SF); ConstraintSet inferSourceFile(TypeEnv& Env, SourceFile* SF);

View file

@ -6,8 +6,10 @@
#include "bolt/CST.hpp" #include "bolt/CST.hpp"
#include "bolt/Type.hpp" #include "bolt/Type.hpp"
#include "bolt/Diagnostics.hpp" #include "bolt/Diagnostics.hpp"
#include <algorithm>
#include <cwchar> #include <cwchar>
#include <functional> #include <functional>
#include <variant>
#include "bolt/Checker.hpp" #include "bolt/Checker.hpp"
namespace bolt { namespace bolt {
@ -109,6 +111,42 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
switch (Expr->getKind()) { switch (Expr->getKind()) {
case NodeKind::MatchExpression:
{
auto E = static_cast<MatchExpression*>(Expr);
Type* MatchTy;
if (E->hasValue()) {
auto [ValOut, ValTy] = inferExpr(Env, E->getValue(), RetTy);
mergeTo(Out, ValOut);
MatchTy = ValTy;
} else {
MatchTy = createTVar();
}
Ty = createTVar();
for (auto Case: E->Cases) {
TypeEnv NewEnv { Env };
auto PattOut = visitPattern(Case->Pattern, MatchTy, NewEnv);
mergeTo(Out, PattOut);
auto [ExprOut, ExprTy] = inferExpr(NewEnv, Case->Expression, RetTy);
mergeTo(Out, ExprOut);
Out.push_back(new CTypesEqual { ExprTy, Ty, Case->Expression });
}
if (E->Value) {
auto ParamTy = createTVar();
Ty = new TFun(ParamTy, Ty);
}
break;
}
case NodeKind::NestedExpression:
{
auto E = static_cast<NestedExpression*>(Expr);
auto [ExprOut, ExprTy] = inferExpr(Env, E->Inner, RetTy);
mergeTo(Out, ExprOut);
Ty = ExprTy;
break;
}
case NodeKind::FunctionExpression: case NodeKind::FunctionExpression:
{ {
auto E = static_cast<FunctionExpression*>(Expr); auto E = static_cast<FunctionExpression*>(Expr);
@ -117,7 +155,8 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
TypeEnv NewEnv { Env }; TypeEnv NewEnv { Env };
for (auto P: E->getParameters()) { for (auto P: E->getParameters()) {
auto TV = createTVar(); auto TV = createTVar();
visitPattern(P, TV, NewEnv); auto ParamOut = visitPattern(P, TV, NewEnv);
mergeTo(Out, ParamOut);
Ty = new TFun(TV, Ty); Ty = new TFun(TV, Ty);
} }
auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy); auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy);
@ -235,18 +274,45 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
return { Out, Ty }; return { Out, Ty };
} }
void Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& Out) { ConstraintSet Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& ToInsert) {
ConstraintSet Out;
switch (P->getKind()) { switch (P->getKind()) {
case NodeKind::BindPattern: case NodeKind::BindPattern:
{ {
auto Q = static_cast<BindPattern*>(P); auto Q = static_cast<BindPattern*>(P);
// TODO Make a TypedNode out of a Pattern? // TODO Make a TypedNode out of a Pattern?
Out.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var); ToInsert.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var);
break; break;
} }
case NodeKind::LiteralPattern:
{
auto Lit = static_cast<LiteralPattern*>(P);
Type* LitTy;
switch (Lit->Literal->getKind()) {
case NodeKind::StringLiteral:
LitTy = getStringType();
break;
case NodeKind::IntegerLiteral:
LitTy = getIntType();
break;
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
Out.push_back(new CTypesEqual { Ty, LitTy, Lit });
break;
}
default:
ZEN_UNREACHABLE
}
return Out;
} }
std::tuple<ConstraintSet, Type*> Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) { std::tuple<ConstraintSet, Type*> Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) {
@ -296,7 +362,8 @@ ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaratio
for (auto It = Params.end(); It-- != Params.begin(); ) { for (auto It = Params.end(); It-- != Params.begin(); ) {
auto Param = *It; auto Param = *It;
auto ParamTy = createTVar(); auto ParamTy = createTVar();
visitPattern(Param->Pattern, ParamTy, NewEnv); auto ParamOut = visitPattern(Param->Pattern, ParamTy, NewEnv);
mergeTo(Out, ParamOut);
Ty = new TFun(ParamTy, Ty); Ty = new TFun(ParamTy, Ty);
} }
@ -418,13 +485,15 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
Node* From; Node* From;
void visitReferenceExpression(ReferenceExpression* E) { void visitReferenceExpression(ReferenceExpression* E) {
auto To = E->getScope()->lookup(E->getSymbolPath()); auto To = E->getScope()->lookup(E->getSymbolPath());
if (To != nullptr) { if (To) {
if (isa<Parameter>(To)) { if (isa<Parameter>(To)) {
To = To->Parent; To = To->Parent;
} }
if (isa<FunctionDeclaration>(To) || isa<VariableDeclaration>(To)) {
G.add_edge(From, To); G.add_edge(From, To);
} }
} }
}
} V { {}, G, From }; } V { {}, G, From };
V.visit(N); V.visit(N);
}; };
@ -433,16 +502,16 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
for (auto Element: Elements) { for (auto Element: Elements) {
if (isa<FunctionDeclaration>(Element)) { if (isa<FunctionDeclaration>(Element)) {
auto M = static_cast<FunctionDeclaration*>(Element); auto Decl = static_cast<FunctionDeclaration*>(Element);
G.add_vertex(Element); G.add_vertex(Decl);
if (M->hasBody()) { if (Decl->hasBody()) {
populate(M, M->getBody()); populate(Decl, Decl->getBody());
} }
} else if (isa<VariableDeclaration>(Element)) { } else if (isa<VariableDeclaration>(Element)) {
auto M = static_cast<VariableDeclaration*>(Element); auto Decl = static_cast<VariableDeclaration*>(Element);
G.add_vertex(Element); G.add_vertex(Decl);
if (M->hasExpression()) { if (Decl->hasExpression()) {
populate(M, M->getExpression()); populate(Decl, Decl->getExpression());
} }
} else { } else {
Stmts.push_back(Element); Stmts.push_back(Element);
@ -545,19 +614,23 @@ ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected,
case NodeKind::FunctionExpression: case NodeKind::FunctionExpression:
{ {
ConstraintSet Out;
auto E = static_cast<FunctionExpression*>(Expr); auto E = static_cast<FunctionExpression*>(Expr);
// FIXME save RetTy on the node and re-use it in this function? // FIXME save RetTy on the node and re-use it in this function?
if (Expected->getKind() == TypeKind::Fun) { if (Expected->getKind() == TypeKind::Fun) {
TypeEnv NewEnv { Env }; TypeEnv NewEnv { Env };
TFun* Ty = static_cast<TFun*>(Expected); TFun* Ty = static_cast<TFun*>(Expected);
for (auto P: E->getParameters()) { for (auto P: E->getParameters()) {
visitPattern(P, Ty->getLeft(), NewEnv); auto ParamOut = visitPattern(P, Ty->getLeft(), NewEnv);
mergeTo(Out, ParamOut);
if (Ty->getRight()->getKind() != TypeKind::Fun) { if (Ty->getRight()->getKind() != TypeKind::Fun) {
goto fallback; goto fallback;
} }
Ty = static_cast<TFun*>(Ty->getRight()); Ty = static_cast<TFun*>(Ty->getRight());
} }
return checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight()); auto ExprOut = checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight());
mergeTo(Out, ExprOut);
return Out;
} }
goto fallback; goto fallback;
} }
@ -619,6 +692,8 @@ void Checker::unifyTypeType(Type* A, Type* B, Node* N) {
void Checker::run(SourceFile* SF) { void Checker::run(SourceFile* SF) {
TypeEnv Env; TypeEnv Env;
auto A = createTVar();
auto B = createTVar();
Env.add("Int", getIntType(), SymbolKind::Type); Env.add("Int", getIntType(), SymbolKind::Type);
Env.add("Bool", getBoolType(), SymbolKind::Type); Env.add("Bool", getBoolType(), SymbolKind::Type);
Env.add("String", getStringType(), SymbolKind::Type); Env.add("String", getStringType(), SymbolKind::Type);
@ -626,6 +701,7 @@ void Checker::run(SourceFile* SF) {
Env.add("False", getBoolType(), SymbolKind::Var); Env.add("False", getBoolType(), SymbolKind::Var);
Env.add("+", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var); Env.add("+", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var);
Env.add("-", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var); Env.add("-", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var);
Env.add("$", new TypeScheme({ A, B }, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var);
auto Out = inferSourceFile(Env, SF); auto Out = inferSourceFile(Env, SF);
solve(Out); solve(Out);
} }