Type-check match-expressions, nested expressions and literal patterns

Also introduces '$' as a new binding
This commit is contained in:
Sam Vervaeck 2024-07-11 10:09:28 +02:00
parent d4af7f5059
commit 4b9fbc1d0c
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
2 changed files with 93 additions and 17 deletions

View file

@ -117,7 +117,7 @@ public:
Type* instantiate(TypeScheme* Scm);
void visitPattern(Pattern* P, Type* Ty, TypeEnv& Out);
ConstraintSet visitPattern(Pattern* P, Type* Ty, TypeEnv& Out);
ConstraintSet inferSourceFile(TypeEnv& Env, SourceFile* SF);

View file

@ -6,8 +6,10 @@
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
#include "bolt/Diagnostics.hpp"
#include <algorithm>
#include <cwchar>
#include <functional>
#include <variant>
#include "bolt/Checker.hpp"
namespace bolt {
@ -109,6 +111,42 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
switch (Expr->getKind()) {
case NodeKind::MatchExpression:
{
auto E = static_cast<MatchExpression*>(Expr);
Type* MatchTy;
if (E->hasValue()) {
auto [ValOut, ValTy] = inferExpr(Env, E->getValue(), RetTy);
mergeTo(Out, ValOut);
MatchTy = ValTy;
} else {
MatchTy = createTVar();
}
Ty = createTVar();
for (auto Case: E->Cases) {
TypeEnv NewEnv { Env };
auto PattOut = visitPattern(Case->Pattern, MatchTy, NewEnv);
mergeTo(Out, PattOut);
auto [ExprOut, ExprTy] = inferExpr(NewEnv, Case->Expression, RetTy);
mergeTo(Out, ExprOut);
Out.push_back(new CTypesEqual { ExprTy, Ty, Case->Expression });
}
if (E->Value) {
auto ParamTy = createTVar();
Ty = new TFun(ParamTy, Ty);
}
break;
}
case NodeKind::NestedExpression:
{
auto E = static_cast<NestedExpression*>(Expr);
auto [ExprOut, ExprTy] = inferExpr(Env, E->Inner, RetTy);
mergeTo(Out, ExprOut);
Ty = ExprTy;
break;
}
case NodeKind::FunctionExpression:
{
auto E = static_cast<FunctionExpression*>(Expr);
@ -117,7 +155,8 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
TypeEnv NewEnv { Env };
for (auto P: E->getParameters()) {
auto TV = createTVar();
visitPattern(P, TV, NewEnv);
auto ParamOut = visitPattern(P, TV, NewEnv);
mergeTo(Out, ParamOut);
Ty = new TFun(TV, Ty);
}
auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy);
@ -235,18 +274,45 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
return { Out, Ty };
}
void Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& Out) {
ConstraintSet Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& ToInsert) {
ConstraintSet Out;
switch (P->getKind()) {
case NodeKind::BindPattern:
{
auto Q = static_cast<BindPattern*>(P);
// TODO Make a TypedNode out of a Pattern?
Out.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var);
ToInsert.add(Q->Name->getCanonicalText(), Ty, SymbolKind::Var);
break;
}
case NodeKind::LiteralPattern:
{
auto Lit = static_cast<LiteralPattern*>(P);
Type* LitTy;
switch (Lit->Literal->getKind()) {
case NodeKind::StringLiteral:
LitTy = getStringType();
break;
case NodeKind::IntegerLiteral:
LitTy = getIntType();
break;
default:
ZEN_UNREACHABLE
}
Out.push_back(new CTypesEqual { Ty, LitTy, Lit });
break;
}
default:
ZEN_UNREACHABLE
}
return Out;
}
std::tuple<ConstraintSet, Type*> Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) {
@ -296,7 +362,8 @@ ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaratio
for (auto It = Params.end(); It-- != Params.begin(); ) {
auto Param = *It;
auto ParamTy = createTVar();
visitPattern(Param->Pattern, ParamTy, NewEnv);
auto ParamOut = visitPattern(Param->Pattern, ParamTy, NewEnv);
mergeTo(Out, ParamOut);
Ty = new TFun(ParamTy, Ty);
}
@ -418,11 +485,13 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
Node* From;
void visitReferenceExpression(ReferenceExpression* E) {
auto To = E->getScope()->lookup(E->getSymbolPath());
if (To != nullptr) {
if (To) {
if (isa<Parameter>(To)) {
To = To->Parent;
}
G.add_edge(From, To);
if (isa<FunctionDeclaration>(To) || isa<VariableDeclaration>(To)) {
G.add_edge(From, To);
}
}
}
} V { {}, G, From };
@ -433,16 +502,16 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
for (auto Element: Elements) {
if (isa<FunctionDeclaration>(Element)) {
auto M = static_cast<FunctionDeclaration*>(Element);
G.add_vertex(Element);
if (M->hasBody()) {
populate(M, M->getBody());
auto Decl = static_cast<FunctionDeclaration*>(Element);
G.add_vertex(Decl);
if (Decl->hasBody()) {
populate(Decl, Decl->getBody());
}
} else if (isa<VariableDeclaration>(Element)) {
auto M = static_cast<VariableDeclaration*>(Element);
G.add_vertex(Element);
if (M->hasExpression()) {
populate(M, M->getExpression());
auto Decl = static_cast<VariableDeclaration*>(Element);
G.add_vertex(Decl);
if (Decl->hasExpression()) {
populate(Decl, Decl->getExpression());
}
} else {
Stmts.push_back(Element);
@ -545,19 +614,23 @@ ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected,
case NodeKind::FunctionExpression:
{
ConstraintSet Out;
auto E = static_cast<FunctionExpression*>(Expr);
// FIXME save RetTy on the node and re-use it in this function?
if (Expected->getKind() == TypeKind::Fun) {
TypeEnv NewEnv { Env };
TFun* Ty = static_cast<TFun*>(Expected);
for (auto P: E->getParameters()) {
visitPattern(P, Ty->getLeft(), NewEnv);
auto ParamOut = visitPattern(P, Ty->getLeft(), NewEnv);
mergeTo(Out, ParamOut);
if (Ty->getRight()->getKind() != TypeKind::Fun) {
goto fallback;
}
Ty = static_cast<TFun*>(Ty->getRight());
}
return checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight());
auto ExprOut = checkExpr(NewEnv, E->getExpression(), Ty->getRight(), Ty->getRight());
mergeTo(Out, ExprOut);
return Out;
}
goto fallback;
}
@ -619,6 +692,8 @@ void Checker::unifyTypeType(Type* A, Type* B, Node* N) {
void Checker::run(SourceFile* SF) {
TypeEnv Env;
auto A = createTVar();
auto B = createTVar();
Env.add("Int", getIntType(), SymbolKind::Type);
Env.add("Bool", getBoolType(), SymbolKind::Type);
Env.add("String", getStringType(), SymbolKind::Type);
@ -626,6 +701,7 @@ void Checker::run(SourceFile* SF) {
Env.add("False", getBoolType(), SymbolKind::Var);
Env.add("+", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var);
Env.add("-", new TFun(getIntType(), new TFun(getIntType(), getIntType())), SymbolKind::Var);
Env.add("$", new TypeScheme({ A, B }, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var);
auto Out = inferSourceFile(Env, SF);
solve(Out);
}