786 lines
21 KiB
C++
786 lines
21 KiB
C++
|
|
#include "bolt/CSTVisitor.hpp"
|
|
#include "zen/graph.hpp"
|
|
|
|
#include "bolt/ByteString.hpp"
|
|
#include "bolt/CST.hpp"
|
|
#include "bolt/Type.hpp"
|
|
#include "bolt/Diagnostics.hpp"
|
|
#include <algorithm>
|
|
#include <cwchar>
|
|
#include <functional>
|
|
#include <variant>
|
|
#include "bolt/Checker.hpp"
|
|
|
|
namespace bolt {
|
|
|
|
/// Appends every constraint of \p Other to the end of \p Out, keeping the
/// order in which they appear in \p Other.
static inline void mergeTo(ConstraintSet& Out, const ConstraintSet& Other) {
  Out.insert(Out.end(), Other.begin(), Other.end());
}
|
|
|
|
/// Finds the type scheme registered under (\p Name, \p Kind).
///
/// The search starts in this environment and then follows the Parent links
/// outwards. The first entry found is returned; nullptr is returned when no
/// environment in the chain has an entry for the key.
TypeScheme* TypeEnv::lookup(ByteString Name, SymbolKind Kind) {
  const auto Key = std::make_tuple(Name, Kind);
  for (auto Scope = this; Scope; Scope = Scope->Parent) {
    auto Found = Scope->Mapping.find(Key);
    if (Found != Scope->Mapping.end()) {
      return Found->second;
    }
  }
  return nullptr;
}
|
|
|
|
/// Registers \p Scm under the key (\p Name, \p Kind) in this environment.
/// Enclosing environments are not modified.
///
/// NOTE(review): the entry is added with emplace(); if Mapping is a
/// unique-key container an already present entry for the same key is
/// presumably kept as-is -- confirm against the declaration of Mapping.
void TypeEnv::add(ByteString Name, TypeScheme* Scm, SymbolKind Kind) {
  auto Key = std::make_tuple(Name, Kind);
  Mapping.emplace(Key, Scm);
}
|
|
|
|
/// Convenience overload: wraps \p Ty in a newly allocated TypeScheme with an
/// empty set of quantified variables and registers it under (\p Name, \p Kind).
void TypeEnv::add(ByteString Name, Type* Ty, SymbolKind Kind) {
  auto Scm = new TypeScheme { {}, Ty };
  add(Name, Scm, Kind);
}
|
|
|
|
using TVSub = std::unordered_map<TVar*, Type*>;
|
|
|
|
Type* substituteType(Type* Ty, const TVSub& Sub) {
|
|
switch (Ty->getKind()) {
|
|
case TypeKind::App:
|
|
{
|
|
auto A = static_cast<TApp*>(Ty);
|
|
auto NewLeft = substituteType(A->getLeft(), Sub);
|
|
auto NewRight = substituteType(A->getRight(), Sub);
|
|
if (A->getLeft() == NewLeft && A->getRight() == NewRight) {
|
|
return Ty;
|
|
}
|
|
return new TApp(NewLeft, NewRight);
|
|
}
|
|
case TypeKind::Con:
|
|
return Ty;
|
|
case TypeKind::Var:
|
|
{
|
|
auto NewTy = Ty->find();
|
|
if (NewTy->getKind() != TypeKind::Var) {
|
|
return substituteType(NewTy, Sub);
|
|
}
|
|
auto Match = Sub.find(static_cast<TVar*>(NewTy));
|
|
return Match == Sub.end()
|
|
? NewTy
|
|
: Match->second;
|
|
}
|
|
case TypeKind::Fun:
|
|
{
|
|
auto F = static_cast<TFun*>(Ty);
|
|
auto NewLeft = substituteType(F->getLeft(), Sub);
|
|
auto NewRight = substituteType(F->getRight(), Sub);
|
|
if (F->getLeft() == NewLeft && F->getRight() == NewRight) {
|
|
return Ty;
|
|
}
|
|
return new TFun(NewLeft, NewRight);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Creates a checker that reports through \p DE and allocates the type
/// constructors for the built-in Int, Bool, String and unit ("()") types.
Checker::Checker(DiagnosticEngine& DE):
  DE(DE) {
  IntType = new TCon("Int");
  BoolType = new TCon("Bool");
  StringType = new TCon("String");
  UnitType = new TCon("()");
}
|
|
|
|
/// Instantiates \p Scm: every variable listed in Scm->Unbound is mapped to a
/// fresh type variable and that mapping is applied to the scheme's type.
Type* Checker::instantiate(TypeScheme* Scm) {
  TVSub FreshVars;
  for (auto Quantified: Scm->Unbound) {
    FreshVars.insert_or_assign(Quantified, createTVar());
  }
  return substituteType(Scm->getType(), FreshVars);
}
|
|
|
|
/// Infers the type of \p Expr in \p Env.
///
/// \param Env   Environment used to resolve names; nested scopes (match cases,
///              function expressions) work on a local TypeEnv built from it.
/// \param Expr  Expression to analyse. Its type is stored on the node with
///              setType() before returning.
/// \param RetTy Type that `return` expressions inside \p Expr must produce.
///              May be nullptr (inferSourceFile() passes nullptr for top-level
///              code), in which case no constraint is generated for a return.
/// \returns The constraints collected while visiting \p Expr together with
///          the (possibly still unsolved) type of the expression.
std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Expr, Type* RetTy) {

  ConstraintSet Out;
  Type* Ty = nullptr;

  // Types requested through type-assertion annotations. They are tied to the
  // inferred type once it is known (after the switch below).
  std::vector<Type*> AssertedTypes;
  for (auto Ann: Expr->Annotations) {
    if (Ann->getKind() == NodeKind::TypeAssertAnnotation) {
      auto [AnnOut, AnnTy] = inferTypeExpr(Env, static_cast<TypeAssertAnnotation*>(Ann)->getTypeExpression());
      mergeTo(Out, AnnOut);
      AssertedTypes.push_back(AnnTy);
    }
  }

  switch (Expr->getKind()) {

    case NodeKind::MatchExpression:
    {
      auto E = static_cast<MatchExpression*>(Expr);
      const bool HasValue = E->hasValue();
      Type* MatchTy;
      if (HasValue) {
        auto [ValOut, ValTy] = inferExpr(Env, E->getValue(), RetTy);
        mergeTo(Out, ValOut);
        MatchTy = ValTy;
      } else {
        MatchTy = createTVar();
      }
      // All case bodies must agree on one result type.
      Ty = createTVar();
      for (auto Case: E->Cases) {
        TypeEnv NewEnv { Env };
        auto PattOut = visitPattern(Case->Pattern, MatchTy, NewEnv);
        mergeTo(Out, PattOut);
        auto [ExprOut, ExprTy] = inferExpr(NewEnv, Case->Expression, RetTy);
        mergeTo(Out, ExprOut);
        Out.push_back(new CTypesEqual { ExprTy, Ty, Case->Expression });
      }
      if (!HasValue) {
        // Without a value the match denotes a function whose parameter is the
        // value the patterns are matched against, so the parameter type has
        // to be MatchTy and not an unrelated fresh variable.
        Ty = new TFun(MatchTy, Ty);
      }
      break;
    }

    case NodeKind::NestedExpression:
    {
      auto E = static_cast<NestedExpression*>(Expr);
      auto [ExprOut, ExprTy] = inferExpr(Env, E->Inner, RetTy);
      mergeTo(Out, ExprOut);
      Ty = ExprTy;
      break;
    }

    case NodeKind::FunctionExpression:
    {
      auto E = static_cast<FunctionExpression*>(Expr);
      Type* NewRetTy = createTVar();
      TypeEnv NewEnv { Env };
      // Visit the parameters left-to-right so bindings enter NewEnv in source
      // order, but remember their types so the arrow can be built afterwards.
      std::vector<Type*> ParamTypes;
      for (auto P: E->getParameters()) {
        auto TV = createTVar();
        auto ParamOut = visitPattern(P, TV, NewEnv);
        mergeTo(Out, ParamOut);
        ParamTypes.push_back(TV);
      }
      // Fold right-to-left so the result reads P1 -> P2 -> ... -> NewRetTy,
      // which is the shape CallExpression and checkExpr() expect.
      Ty = NewRetTy;
      for (auto It = ParamTypes.rbegin(); It != ParamTypes.rend(); ++It) {
        Ty = new TFun(*It, Ty);
      }
      auto [ExprOut, ExprTy] = inferExpr(NewEnv, E->getExpression(), NewRetTy);
      mergeTo(Out, ExprOut);
      Out.push_back(new CTypesEqual { ExprTy, NewRetTy, E });
      break;
    }

    case NodeKind::BlockExpression:
    {
      auto E = static_cast<BlockExpression*>(Expr);
      auto N = E->Elements.size();
      if (N == 0) {
        // Nothing to evaluate: give the block the unit type instead of
        // indexing Elements[N-1] out of bounds.
        Ty = getUnitType();
        break;
      }
      for (std::size_t I = 0; I+1 < N; ++I) {
        auto Element = E->Elements[I];
        auto ElementOut = inferElement(Env, Element, RetTy);
        mergeTo(Out, ElementOut);
      }
      // The value of the block is the value of its last element.
      auto Last = E->Elements[N-1];
      auto [LastOut, LastTy] = inferExpr(Env, cast<Expression>(Last), RetTy);
      mergeTo(Out, LastOut);
      Ty = LastTy;
      break;
    }

    case NodeKind::ReferenceExpression:
    {
      auto E = static_cast<ReferenceExpression*>(Expr);
      auto Name = E->Name.getCanonicalText();
      auto Match = Env.lookup(Name, SymbolKind::Var);
      if (Match == nullptr) {
        DE.add<BindingNotFoundDiagnostic>(Name, E->Name);
        // Keep going with an unconstrained type after reporting the error.
        Ty = createTVar();
      } else {
        Ty = instantiate(Match);
      }
      break;
    }

    case NodeKind::LiteralExpression:
    {
      auto E = static_cast<LiteralExpression*>(Expr);
      switch (E->Token->getKind()) {
        case NodeKind::IntegerLiteral:
          Ty = getIntType();
          break;
        case NodeKind::StringLiteral:
          Ty = getStringType();
          break;
        default:
          ZEN_UNREACHABLE
      }
      break;
    }

    case NodeKind::CallExpression:
    {
      auto E = static_cast<CallExpression*>(Expr);
      // Result type of the call. Deliberately NOT named RetTy: shadowing the
      // parameter made `return` expressions inside the arguments unify with
      // the call result instead of the enclosing function's return type.
      auto CallTy = createTVar();
      Type* FunTy = CallTy;
      // Walk the arguments back-to-front so the arrow ends up as
      // A1 -> A2 -> ... -> CallTy.
      for (auto It = E->Args.end(); It != E->Args.begin(); ) {
        --It;
        auto [ArgOut, ArgTy] = inferExpr(Env, *It, RetTy);
        mergeTo(Out, ArgOut);
        FunTy = new TFun(ArgTy, FunTy);
      }
      auto FunOut = checkExpr(Env, E->Function, FunTy, RetTy);
      mergeTo(Out, FunOut);
      Ty = CallTy;
      break;
    }

    case NodeKind::InfixExpression:
    {
      auto E = static_cast<InfixExpression*>(Expr);
      auto [LeftOut, LeftTy] = inferExpr(Env, E->Left, RetTy);
      mergeTo(Out, LeftOut);
      auto [RightOut, RightTy] = inferExpr(Env, E->Right, RetTy);
      mergeTo(Out, RightOut);
      auto Name = E->Operator.getCanonicalText();
      auto Match = Env.lookup(Name, SymbolKind::Var);
      if (Match == nullptr) {
        DE.add<BindingNotFoundDiagnostic>(Name, E->Operator);
        // Fall through to the common exit so the node still gets a type.
        Ty = createTVar();
        break;
      }
      auto ResultTy = createTVar();
      auto FunTy = new TFun(LeftTy, new TFun(RightTy, ResultTy));
      Out.push_back(new CTypesEqual(FunTy, instantiate(Match), E));
      Ty = ResultTy;
      break;
    }

    case NodeKind::ReturnExpression:
    {
      auto E = static_cast<ReturnExpression*>(Expr);
      Type* ReturnedTy;
      if (E->hasExpression()) {
        auto [ValOut, ValTy] = inferExpr(Env, E->getExpression(), RetTy);
        mergeTo(Out, ValOut);
        ReturnedTy = ValTy;
      } else {
        ReturnedTy = getUnitType();
      }
      if (RetTy != nullptr) {
        Out.push_back(new CTypesEqual { ReturnedTy, RetTy, E });
      }
      // TODO report a diagnostic for a return without an enclosing function
      // (RetTy == nullptr) instead of silently accepting it.

      // Since evaluation stops at the return expression, it can be matched with any type.
      Ty = createTVar();
      break;
    }

    // TODO LambdaExpression

    default:
      ZEN_UNREACHABLE

  }

  // Enforce the type assertions attached to the expression.
  for (auto AssertedTy: AssertedTypes) {
    Out.push_back(new CTypesEqual { Ty, AssertedTy, Expr });
  }

  Expr->setType(Ty);

  return { Out, Ty };
}
|
|
|
|
/// Matches pattern \p P against a value of type \p Ty.
///
/// A bind pattern adds its name with type \p Ty to \p ToInsert as a variable
/// and produces no constraints. A literal pattern produces one constraint
/// equating \p Ty with the literal's type (String or Int).
///
/// \returns The constraints generated by the pattern.
ConstraintSet Checker::visitPattern(Pattern* P, Type* Ty, TypeEnv& ToInsert) {

  switch (P->getKind()) {

    case NodeKind::BindPattern:
    {
      auto Bind = static_cast<BindPattern*>(P);
      // TODO Make a TypedNode out of a Pattern?
      ToInsert.add(Bind->Name->getCanonicalText(), Ty, SymbolKind::Var);
      return {};
    }

    case NodeKind::LiteralPattern:
    {
      auto Lit = static_cast<LiteralPattern*>(P);
      Type* LitTy;
      switch (Lit->Literal->getKind()) {
        case NodeKind::IntegerLiteral:
          LitTy = getIntType();
          break;
        case NodeKind::StringLiteral:
          LitTy = getStringType();
          break;
        default:
          ZEN_UNREACHABLE
      }
      return { new CTypesEqual { Ty, LitTy, Lit } };
    }

    default:
      ZEN_UNREACHABLE

  }

  return {};
}
|
|
|
|
/// Translates the type expression \p TE into a Type.
///
/// A reference is looked up in \p Env as a SymbolKind::Type and instantiated;
/// an unknown name is reported and replaced by a fresh type variable. An
/// arrow `P1 -> ... -> Pn -> R` becomes nested TFun nodes in that order.
/// The resulting type is also stored on \p TE with setType().
///
/// \returns The constraints produced by all sub-expressions and the type.
std::tuple<ConstraintSet, Type*> Checker::inferTypeExpr(TypeEnv& Env, TypeExpression* TE) {

  ConstraintSet Out;
  Type* Ty;

  switch (TE->getKind()) {

    case NodeKind::ReferenceTypeExpression:
    {
      auto Ref = static_cast<ReferenceTypeExpression*>(TE);
      auto Name = Ref->Name->getCanonicalText();
      auto Match = Env.lookup(Name, SymbolKind::Type);
      if (Match == nullptr) {
        DE.add<BindingNotFoundDiagnostic>(Name, Ref->Name);
        Ty = createTVar();
      } else {
        Ty = instantiate(Match);
      }
      break;
    }

    case NodeKind::ArrowTypeExpression:
    {
      auto Arrow = static_cast<ArrowTypeExpression*>(TE);
      auto [ReturnOut, ReturnTy] = inferTypeExpr(Env, Arrow->ReturnType);
      // The constraints of the return type must not be lost.
      mergeTo(Out, ReturnOut);
      std::vector<Type*> ParamTypes;
      for (auto PT: Arrow->ParamTypes) {
        auto [ParamOut, ParamTy] = inferTypeExpr(Env, PT);
        mergeTo(Out, ParamOut);
        ParamTypes.push_back(ParamTy);
      }
      // Fold from the last parameter to the first so that the first
      // parameter ends up outermost: P1 -> (P2 -> (... -> R)). This is the
      // order inferFunctionDeclaration() uses for the declared parameters.
      Ty = ReturnTy;
      for (auto It = ParamTypes.rbegin(); It != ParamTypes.rend(); ++It) {
        Ty = new TFun(*It, Ty);
      }
      break;
    }

    default:
      ZEN_UNREACHABLE

  }

  TE->setType(Ty);

  return { Out, Ty };
}
|
|
|
|
/// Generates the constraints for the function declaration \p D.
///
/// The parameters are bound in a local environment built from \p Env, the
/// function type `P1 -> ... -> Pn -> Ret` is assembled from fresh variables,
/// and it is equated with the optional type assertion, with the type of the
/// body (when present) and with the type already stored on \p D.
///
/// \returns The collected constraints; nothing is solved here.
ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaration* D) {

  auto TA = D->getTypeAssert();
  auto Params = D->getParams();
  auto Body = D->getBody();

  ConstraintSet Out;

  TypeEnv NewEnv { Env };

  auto RetTy = createTVar();

  // Build the arrow from the last parameter to the first so that the first
  // parameter is the outermost argument. The iterator is decremented inside
  // the loop so begin() itself is never decremented.
  Type* Ty = RetTy;
  for (auto It = Params.end(); It != Params.begin(); ) {
    --It;
    auto Param = *It;
    auto ParamTy = createTVar();
    auto ParamOut = visitPattern(Param->Pattern, ParamTy, NewEnv);
    mergeTo(Out, ParamOut);
    Ty = new TFun(ParamTy, Ty);
  }

  if (TA != nullptr) {
    auto [TEOut, TETy] = inferTypeExpr(Env, TA->TypeExpression);
    mergeTo(Out, TEOut);
    Out.push_back(new CTypesEqual(Ty, TETy, TA->TypeExpression));
  }

  if (Body != nullptr) {
    // TODO eliminate BlockBody and replace with BlockExpr
    ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody);
    auto [BodyOut, BodyTy] = inferExpr(NewEnv, cast<LetExprBody>(Body)->Expression, RetTy);
    mergeTo(Out, BodyOut);
    Out.push_back(new CTypesEqual(RetTy, BodyTy, Body));
  }

  // inferMany() will have set the type of the node to a fresh type variable.
  // NOTE(review): inferElement() also calls this function; confirm that the
  // node's type has been set on that path as well.
  Out.push_back(new CTypesEqual { D->getType(), Ty, D });

  return Out;
}
|
|
|
|
/// Generates the constraints for the variable declaration \p Decl and adds
/// the variable to \p Env.
///
/// The variable's type is the asserted type when there is a type assertion
/// (the body, if any, is then required to have that type), otherwise the
/// type of the body, otherwise a fresh type variable.
///
/// \param RetTy Return type forwarded to inferExpr() for the body.
/// \returns The collected constraints.
ConstraintSet Checker::inferVariableDeclaration(TypeEnv& Env, VariableDeclaration* Decl, Type* RetTy) {

  ConstraintSet Out;

  Type* Ty = nullptr;

  if (Decl->TypeAssert != nullptr) {
    auto [AssertOut, AssertTy] = inferTypeExpr(Env, Decl->TypeAssert->TypeExpression);
    mergeTo(Out, AssertOut);
    Ty = AssertTy;
  }

  if (Decl->Body != nullptr) {
    // TODO eliminate BlockBody and replace with BlockExpr
    ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
    auto [BodyOut, BodyTy] = inferExpr(Env, cast<LetExprBody>(Decl->Body)->Expression, RetTy);
    mergeTo(Out, BodyOut);
    if (Ty == nullptr) {
      Ty = BodyTy;
    } else {
      Out.push_back(new CTypesEqual(Ty, BodyTy, Decl->Body));
    }
  }

  if (Ty == nullptr) {
    // Neither a type assertion nor a body: leave the type open rather than
    // registering a null type, which would be dereferenced on first use.
    Ty = createTVar();
  }

  // Currently we don't perform generalisation on variable declarations
  Env.add(Decl->getNameAsString(), Ty, SymbolKind::Var);

  return Out;
}
|
|
|
|
/// Returns true when the type variable \p TV occurs somewhere inside \p Ty.
///
/// Type variables met on the way are resolved with find(); when a variable
/// is already bound to a non-variable type the search continues inside that
/// binding.
bool hasTypeVar(Type* Ty, TVar* TV) {
  // Dispatch on the type being searched, not on TV (which is always a
  // variable and would send every type down the Var branch).
  switch (Ty->getKind()) {
    case TypeKind::App:
    {
      auto T = static_cast<TApp*>(Ty);
      return hasTypeVar(T->getLeft(), TV)
          || hasTypeVar(T->getRight(), TV);
    }
    case TypeKind::Con:
      return false;
    case TypeKind::Fun:
    {
      auto T = static_cast<TFun*>(Ty);
      return hasTypeVar(T->getLeft(), TV)
          || hasTypeVar(T->getRight(), TV);
    }
    case TypeKind::Var:
    {
      auto Solved = Ty->find();
      if (Solved->getKind() != TypeKind::Var) {
        return hasTypeVar(Solved, TV);
      }
      return Solved == TV;
    }
  }
}
|
|
|
|
/// Returns true when the type variable \p TV occurs in the type of any
/// scheme registered in this environment or in one of its parents.
///
/// Parents are included because lookup() resolves names through the same
/// chain, so a variable mentioned by an outer binding is still in use.
bool TypeEnv::hasVar(TVar* TV) const {
  for (const TypeEnv* Scope = this; Scope != nullptr; Scope = Scope->Parent) {
    // Bind by reference: copying each entry would copy the key as well.
    for (const auto& [_, Scm]: Scope->Mapping) {
      if (Scm->Unbound.count(TV)) {
        // FIXME
        ZEN_UNREACHABLE
      }
      if (hasTypeVar(Scm->getType(), TV)) {
        return true;
      }
    }
  }
  return false;
}
|
|
|
|
auto getUnbound(const TypeEnv& Env, Type* Ty) {
|
|
struct Visitor : public TypeVisitor {
|
|
const TypeEnv& Env;
|
|
Visitor(const TypeEnv& Env):
|
|
Env(Env) {}
|
|
std::vector<TVar*> Out;
|
|
void visitVar(TVar* TV) {
|
|
auto Solved = TV->find();
|
|
if (isa<TVar>(Solved)) {
|
|
auto Var = static_cast<TVar*>(Solved);
|
|
if (!Env.hasVar(Var)) {
|
|
Out.push_back(Var);
|
|
}
|
|
} else {
|
|
visit(Solved);
|
|
}
|
|
}
|
|
} V { Env };
|
|
V.visit(Ty);
|
|
return V.Out;
|
|
}
|
|
|
|
/// Infers a sequence of elements in which declarations may refer to each
/// other regardless of their textual order.
///
/// Function and variable declarations become vertices of a reference graph;
/// an edge is added for every reference found in a declaration's body or
/// expression that resolves to a function or variable declaration. The graph
/// is then processed group by group in the order produced by zen::toposort():
/// each function of a group first receives a fresh type variable, the group's
/// constraints are generated and solved immediately, and finally a type
/// scheme is built for each function from the variables getUnbound() reports.
///
/// Elements that are not declarations are inferred afterwards, in their
/// original order.
///
/// \returns Only the constraints of those remaining elements; the constraints
///          of the declaration groups have already been solved.
ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Type* RetTy) {

  using Graph = zen::hash_graph<Node*>;

  Graph G;

  // Adds an edge From -> X for every declaration X referenced inside N.
  auto populate = [&G](Node* From, Node* N) {
    struct Visitor : CSTVisitor<Visitor> {
      Graph& G;
      Node* From;
      void visitReferenceExpression(ReferenceExpression* E) {
        auto To = E->getScope()->lookup(E->getSymbolPath());
        if (To) {
          if (isa<Parameter>(To)) {
            // A reference to a parameter is recorded against the
            // parameter's parent node.
            To = To->Parent;
          }
          if (isa<FunctionDeclaration>(To) || isa<VariableDeclaration>(To)) {
            G.add_edge(From, To);
          }
        }
      }
    } V { {}, G, From };
    V.visit(N);
  };

  std::vector<Node*> Stmts;

  for (auto Element: Elements) {
    if (isa<FunctionDeclaration>(Element)) {
      auto Decl = static_cast<FunctionDeclaration*>(Element);
      G.add_vertex(Decl);
      if (Decl->hasBody()) {
        populate(Decl, Decl->getBody());
      }
    } else if (isa<VariableDeclaration>(Element)) {
      auto Decl = static_cast<VariableDeclaration*>(Element);
      G.add_vertex(Decl);
      if (Decl->hasExpression()) {
        populate(Decl, Decl->getExpression());
      }
    } else {
      Stmts.push_back(Element);
    }
  }

  for (auto Mutual: zen::toposort(G)) {

    ConstraintSet Out;

    // Make every function of the group visible (with an open type) before
    // any of the bodies is inferred.
    for (auto N: Mutual) {
      if (isa<FunctionDeclaration>(N)) {
        auto Func = static_cast<FunctionDeclaration*>(N);
        Type* Ty = createTVar();
        Func->setType(Ty);
        Env.add(Func->getNameAsString(), Ty, SymbolKind::Var);
      }
    }

    for (auto N: Mutual) {
      if (isa<FunctionDeclaration>(N)) {
        mergeTo(Out, inferFunctionDeclaration(Env, static_cast<FunctionDeclaration*>(N)));
      } else if (isa<VariableDeclaration>(N)) {
        mergeTo(Out, inferVariableDeclaration(Env, static_cast<VariableDeclaration*>(N), RetTy));
      } else {
        ZEN_UNREACHABLE
      }
    }

    solve(Out);

    // Build a scheme for each function from the variables that are still
    // unsolved and not mentioned by the environment.
    for (auto N: Mutual) {
      if (isa<FunctionDeclaration>(N)) {
        auto Func = static_cast<FunctionDeclaration*>(N);
        auto Unbound = getUnbound(Env, Func->getType());
        Env.add(
          Func->getNameAsString(),
          new TypeScheme { { Unbound.begin(), Unbound.end() }, Func->getType()->find() },
          SymbolKind::Var
        );
      }
    }

  }

  ConstraintSet Out;

  for (auto Stmt: Stmts) {
    mergeTo(Out, inferElement(Env, Stmt, RetTy));
  }

  return Out;
}
|
|
|
|
/// Generates the constraints for a single element \p N of a block or file.
///
/// Expressions are forwarded to inferExpr(), variable declarations to
/// inferVariableDeclaration() and function declarations to
/// inferFunctionDeclaration(). Any other node kind is unreachable.
///
/// \param RetTy Type that `return` expressions must produce; may be nullptr.
ConstraintSet Checker::inferElement(TypeEnv& Env, Node* N, Type* RetTy) {

  if (isa<Expression>(N)) {
    return std::get<0>(inferExpr(Env, cast<Expression>(N), RetTy));
  }

  if (isa<VariableDeclaration>(N)) {
    // Variable declarations may also appear as elements (e.g. in a block).
    return inferVariableDeclaration(Env, static_cast<VariableDeclaration*>(N), RetTy);
  }

  switch (N->getKind()) {

    case NodeKind::PrefixFunctionDeclaration:
    case NodeKind::InfixFunctionDeclaration:
    case NodeKind::SuffixFunctionDeclaration:
    case NodeKind::NamedFunctionDeclaration:
      return inferFunctionDeclaration(Env, static_cast<FunctionDeclaration*>(N));

    case NodeKind::ReturnExpression:
    {
      auto M = static_cast<ReturnExpression*>(N);
      if (!M->hasExpression()) {
        return {};
      }
      auto [ValOut, ValTy] = inferExpr(Env, M->getExpression(), RetTy);
      // Keep the constraints of the returned expression; only add the
      // return-type constraint when there is a return type to compare with.
      if (RetTy != nullptr) {
        ValOut.push_back(new CTypesEqual(ValTy, RetTy, N));
      }
      return ValOut;
    }

    default:
      ZEN_UNREACHABLE

  }

}
|
|
|
|
/// Infers all top-level elements of \p SF. No return type is in effect at
/// file level, so nullptr is passed as the return type.
ConstraintSet Checker::inferSourceFile(TypeEnv& Env, SourceFile* SF) {
  Type* NoReturnType = nullptr;
  return inferMany(Env, SF->Elements, NoReturnType);
}
|
|
|
|
/// Checks \p Expr against the type \p Expected.
///
/// Two shapes are handled directly:
///  - a literal whose type already equals \p Expected needs no constraint;
///  - a function expression whose parameters can each be matched with one
///    arrow of \p Expected has its parameter patterns bound to the expected
///    parameter types and its body checked against what remains.
/// Everything else is inferred with inferExpr() and the inferred type is
/// equated with \p Expected.
///
/// On every path the expression node ends up with a type set.
///
/// \param RetTy Type that `return` expressions inside \p Expr must produce.
/// \returns The constraints that make \p Expr have type \p Expected.
ConstraintSet Checker::checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, Type* RetTy) {

  switch (Expr->getKind()) {

    case NodeKind::LiteralExpression:
    {
      auto E = static_cast<LiteralExpression*>(Expr);
      switch (E->Token->getKind()) {
        case NodeKind::IntegerLiteral:
          if (*Expected == *getIntType()) {
            Expr->setType(getIntType());
            return {};
          }
          break;
        case NodeKind::StringLiteral:
          if (*Expected == *getStringType()) {
            Expr->setType(getStringType());
            return {};
          }
          break;
        default:
          ZEN_UNREACHABLE;
      }
      break;
    }

    case NodeKind::FunctionExpression:
    {
      auto E = static_cast<FunctionExpression*>(Expr);
      ConstraintSet Out;
      TypeEnv NewEnv { Env };
      // FIXME save RetTy on the node and re-use it in this function?
      // Peel one arrow off the expected type per parameter. Whatever is
      // left after the last parameter is the type of the body.
      Type* Remaining = Expected;
      bool ShapeMatches = true;
      for (auto P: E->getParameters()) {
        if (Remaining->getKind() != TypeKind::Fun) {
          ShapeMatches = false;
          break;
        }
        auto Arrow = static_cast<TFun*>(Remaining);
        auto ParamOut = visitPattern(P, Arrow->getLeft(), NewEnv);
        mergeTo(Out, ParamOut);
        Remaining = Arrow->getRight();
      }
      if (ShapeMatches) {
        auto ExprOut = checkExpr(NewEnv, E->getExpression(), Remaining, Remaining);
        mergeTo(Out, ExprOut);
        Expr->setType(Expected);
        return Out;
      }
      // The expected type has fewer arrows than there are parameters: the
      // partial work above is discarded and the generic path takes over.
      break;
    }

    default:
      break;

  }

  // Generic path: infer the type and require it to equal the expected one.
  auto [Out, Actual] = inferExpr(Env, Expr, RetTy);
  Out.push_back(new CTypesEqual(Actual, Expected, Expr));
  return Out;
}
|
|
|
|
/// Processes \p Constraints in order. Each type-equality constraint is
/// handed to unifyTypeType() together with the node it originated from.
void Checker::solve(const std::vector<Constraint*>& Constraints) {
  for (auto C: Constraints) {
    if (C->getKind() != ConstraintKind::TypesEqual) {
      continue;
    }
    auto Equality = static_cast<CTypesEqual*>(C);
    unifyTypeType(Equality->getLeft(), Equality->getRight(), Equality->getOrigin());
  }
}
|
|
|
|
/// Unifies the types \p A and \p B.
///
/// Both sides are resolved with find() first. An unsolved variable is bound
/// to the other side; constructors unify when their names are equal;
/// function and application types unify component-wise. Anything else is
/// reported as a TypeMismatchError attached to \p N.
void Checker::unifyTypeType(Type* A, Type* B, Node* N) {
  A = A->find();
  B = B->find();
  if (A == B) {
    // Same representative: nothing to do. This also keeps a variable from
    // being bound to itself.
    return;
  }
  if (A->getKind() == TypeKind::Var) {
    auto TV = static_cast<TVar*>(A);
    // TODO occurs check
    TV->set(B);
    return;
  }
  if (B->getKind() == TypeKind::Var) {
    auto TV = static_cast<TVar*>(B);
    // TODO occurs check
    TV->set(A);
    return;
  }
  if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) {
    auto C1 = static_cast<TCon*>(A);
    auto C2 = static_cast<TCon*>(B);
    if (C1->getName() == C2->getName()) {
      return;
    }
  }
  if (A->getKind() == TypeKind::Fun && B->getKind() == TypeKind::Fun) {
    auto F1 = static_cast<TFun*>(A);
    auto F2 = static_cast<TFun*>(B);
    unifyTypeType(F1->getLeft(), F2->getLeft(), N);
    unifyTypeType(F1->getRight(), F2->getRight(), N);
    return;
  }
  if (A->getKind() == TypeKind::App && B->getKind() == TypeKind::App) {
    auto P1 = static_cast<TApp*>(A);
    auto P2 = static_cast<TApp*>(B);
    unifyTypeType(P1->getLeft(), P2->getLeft(), N);
    unifyTypeType(P1->getRight(), P2->getRight(), N);
    return;
  }
  DE.add<TypeMismatchError>(A, B, N);
}
|
|
|
|
/// Type-checks the source file \p SF.
///
/// A fresh environment is seeded with the built-in type names (Int, Bool,
/// String), the values True and False, and the functions `not`, `+`, `-`
/// and `$`. The constraints returned for the file are then solved.
void Checker::run(SourceFile* SF) {

  TypeEnv Env;

  auto A = createTVar();
  auto B = createTVar();
  auto BoolTy = getBoolType();
  auto IntTy = getIntType();
  auto StringTy = getStringType();

  // Built-in type names.
  Env.add("Int", IntTy, SymbolKind::Type);
  Env.add("Bool", BoolTy, SymbolKind::Type);
  Env.add("String", StringTy, SymbolKind::Type);

  // Built-in values and functions.
  Env.add("True", BoolTy, SymbolKind::Var);
  Env.add("False", BoolTy, SymbolKind::Var);
  Env.add("not", new TFun(BoolTy, BoolTy), SymbolKind::Var);
  Env.add("+", new TFun(IntTy, new TFun(IntTy, IntTy)), SymbolKind::Var);
  Env.add("-", new TFun(IntTy, new TFun(IntTy, IntTy)), SymbolKind::Var);
  // `$` is quantified over both A and B: (A -> B) -> A -> B.
  Env.add("$", new TypeScheme({ A, B }, new TFun(new TFun(A, B), new TFun(A, B))), SymbolKind::Var);

  auto Constraints = inferSourceFile(Env, SF);
  solve(Constraints);
}
|
|
|
|
/// Returns \p Ty with every solved type variable replaced by its binding.
///
/// Variables that are still unsolved after find() are returned as they are.
/// Application and function nodes are re-allocated only when one of their
/// children changed; otherwise the original node is returned.
Type* resolveType(Type* Ty) {
  switch (Ty->getKind()) {
    case TypeKind::Con:
      return Ty;
    case TypeKind::Var:
    {
      auto Solved = Ty->find();
      if (Solved->getKind() == TypeKind::Var) {
        return Solved;
      }
      return resolveType(Solved);
    }
    case TypeKind::App:
    {
      auto App = static_cast<TApp*>(Ty);
      auto Left = resolveType(App->getLeft());
      auto Right = resolveType(App->getRight());
      const bool Changed = App->getLeft() != Left || App->getRight() != Right;
      if (!Changed) {
        return Ty;
      }
      return new TApp(Left, Right);
    }
    case TypeKind::Fun:
    {
      auto Fun = static_cast<TFun*>(Ty);
      auto Left = resolveType(Fun->getLeft());
      auto Right = resolveType(Fun->getRight());
      const bool Changed = Fun->getLeft() != Left || Fun->getRight() != Right;
      if (!Changed) {
        return Ty;
      }
      return new TFun(Left, Right);
    }
  }
}
|
|
|
|
/// Returns the type stored on \p N with all solved type variables replaced
/// by their bindings. \p N is cast to TypedNode.
Type* Checker::getTypeOfNode(Node* N) {
  return resolveType(cast<TypedNode>(N)->getType());
}
|
|
|
|
}
|