2022-08-21 16:25:52 +02:00
|
|
|
|
|
|
|
#include <stack>
|
|
|
|
|
2022-08-21 20:56:58 +02:00
|
|
|
#include "bolt/Diagnostics.hpp"
|
2022-08-21 16:25:52 +02:00
|
|
|
#include "zen/config.hpp"
|
|
|
|
|
|
|
|
#include "bolt/CST.hpp"
|
|
|
|
#include "bolt/Checker.hpp"
|
|
|
|
|
|
|
|
namespace bolt {
|
|
|
|
|
2022-08-25 19:04:25 +02:00
|
|
|
// Returns a textual description of Ty. Only declared here; the definition
// lives elsewhere. In this file it is referenced only by a commented-out
// debug print in Checker::solve().
std::string describe(const Type* Ty);
|
2022-08-21 20:56:58 +02:00
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Inserts every type variable occurring in this type into TVs.
/// Composite types are walked depth-first, left to right; TAny contributes nothing.
void Type::addTypeVars(TVSet& TVs) {
  if (Kind == TypeKind::Var) {
    TVs.emplace(static_cast<TVar*>(this));
    return;
  }
  if (Kind == TypeKind::Arrow) {
    auto ArrowTy = static_cast<TArrow*>(this);
    for (auto ParamTy: ArrowTy->ParamTypes) {
      ParamTy->addTypeVars(TVs);
    }
    ArrowTy->ReturnType->addTypeVars(TVs);
    return;
  }
  if (Kind == TypeKind::Con) {
    for (auto ArgTy: static_cast<TCon*>(this)->Args) {
      ArgTy->addTypeVars(TVs);
    }
    return;
  }
  if (Kind == TypeKind::Tuple) {
    for (auto ElemTy: static_cast<TTuple*>(this)->ElementTypes) {
      ElemTy->addTypeVars(TVs);
    }
    return;
  }
  // TypeKind::Any: no variables to collect.
}
|
|
|
|
|
2022-08-21 16:25:52 +02:00
|
|
|
bool Type::hasTypeVar(const TVar* TV) {
|
|
|
|
switch (Kind) {
|
|
|
|
case TypeKind::Var:
|
|
|
|
return static_cast<TVar*>(this)->Id == TV->Id;
|
|
|
|
case TypeKind::Arrow:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<TArrow*>(this);
|
|
|
|
for (auto Ty: Y->ParamTypes) {
|
|
|
|
if (Ty->hasTypeVar(TV)) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return Y->ReturnType->hasTypeVar(TV);
|
|
|
|
}
|
2022-08-21 20:56:58 +02:00
|
|
|
case TypeKind::Con:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<TCon*>(this);
|
|
|
|
for (auto Ty: Y->Args) {
|
|
|
|
if (Ty->hasTypeVar(TV)) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
2022-08-25 23:04:09 +02:00
|
|
|
case TypeKind::Tuple:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<TTuple*>(this);
|
|
|
|
for (auto Ty: Y->ElementTypes) {
|
|
|
|
if (Ty->hasTypeVar(TV)) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
2022-08-21 20:56:58 +02:00
|
|
|
case TypeKind::Any:
|
|
|
|
return false;
|
2022-08-21 16:25:52 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Applies the substitution Sub to this type.
///
/// Variables bound in Sub are replaced, and the replacement is itself
/// substituted again so chained bindings resolve fully. Composite types are
/// only re-allocated when at least one component actually changed; otherwise
/// `this` is returned unchanged.
Type* Type::substitute(const TVSub &Sub) {

  // Substitutes each element of Input and appends the results to Output.
  // Returns true when any element was replaced by a different type.
  auto MapAll = [&Sub](const auto& Input, std::vector<Type*>& Output) {
    bool AnyChanged = false;
    for (auto Ty: Input) {
      auto Replaced = Ty->substitute(Sub);
      if (Replaced != Ty) {
        AnyChanged = true;
      }
      Output.push_back(Replaced);
    }
    return AnyChanged;
  };

  switch (Kind) {

    case TypeKind::Var:
    {
      auto Self = static_cast<TVar*>(this);
      auto Found = Sub.find(Self);
      if (Found == Sub.end()) {
        return Self;
      }
      return Found->second->substitute(Sub);
    }

    case TypeKind::Arrow:
    {
      auto Self = static_cast<TArrow*>(this);
      std::vector<Type*> Params;
      bool Changed = MapAll(Self->ParamTypes, Params);
      auto Ret = Self->ReturnType->substitute(Sub);
      if (Ret != Self->ReturnType) {
        Changed = true;
      }
      return Changed ? new TArrow(Params, Ret) : this;
    }

    case TypeKind::Any:
      return this;

    case TypeKind::Con:
    {
      auto Self = static_cast<TCon*>(this);
      std::vector<Type*> Args;
      bool Changed = MapAll(Self->Args, Args);
      return Changed ? new TCon(Self->Id, Args, Self->DisplayName) : this;
    }

    case TypeKind::Tuple:
    {
      auto Self = static_cast<TTuple*>(this);
      std::vector<Type*> Elems;
      bool Changed = MapAll(Self->ElementTypes, Elems);
      return Changed ? new TTuple(Elems) : this;
    }

  }
}
|
|
|
|
|
2022-08-25 19:04:25 +02:00
|
|
|
/// Returns this constraint with Sub applied to every type it mentions.
/// Equal and Many constraints are rebuilt as new objects; Empty is returned as-is.
Constraint* Constraint::substitute(const TVSub &Sub) {
  switch (Kind) {

    case ConstraintKind::Empty:
      return this;

    case ConstraintKind::Equal:
    {
      auto Eq = static_cast<CEqual*>(this);
      auto NewLeft = Eq->Left->substitute(Sub);
      auto NewRight = Eq->Right->substitute(Sub);
      return new CEqual(NewLeft, NewRight, Eq->Source);
    }

    case ConstraintKind::Many:
    {
      auto Group = static_cast<CMany*>(this);
      // NOTE(review): the set is heap-allocated and handed to CMany by
      // dereference, presumably because CMany may keep a reference to it —
      // confirm before changing to a local.
      auto Rewritten = new ConstraintSet();
      for (auto Element: Group->Elements) {
        Rewritten->push_back(Element->substitute(Sub));
      }
      return new CMany(*Rewritten);
    }

  }
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Creates a checker reporting to DE and allocates the built-in nominal types.
Checker::Checker(DiagnosticEngine& DE):
  DE(DE) {
  // Each call draws the next constructor id, so ids follow the call order
  // below: Bool, then Int, then String.
  auto Builtin = [this](const char* Name) {
    return new TCon(nextConTypeId++, {}, Name);
  };
  BoolType = Builtin("Bool");
  IntType = Builtin("Int");
  StringType = Builtin("String");
}
|
|
|
|
|
|
|
|
/// Finds the scheme bound to Name, searching from the innermost context
/// outward. Returns a pointer into that context's environment, or nullptr
/// when no context binds Name.
Scheme* Checker::lookup(ByteString Name) {
  for (auto Ctx = Contexts.rbegin(); Ctx != Contexts.rend(); ++Ctx) {
    auto& Env = (*Ctx)->Env;
    auto Found = Env.find(Name);
    if (Found != Env.end()) {
      return &Found->second;
    }
  }
  return nullptr;
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Looks up Name and returns its type, asserting that the binding is not
/// generalized over any type variables. Returns nullptr when Name is unbound.
Type* Checker::lookupMono(ByteString Name) {
  auto Scm = lookup(Name);
  if (!Scm) {
    return nullptr;
  }
  auto& F = Scm->as<Forall>();
  // A monomorphic lookup must not hit a scheme with quantified variables.
  ZEN_ASSERT(F.TVs == nullptr || F.TVs->empty());
  return F.Type;
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Binds Name to the scheme S in the innermost context.
/// NOTE(review): emplace() does not overwrite an existing entry if Env is a
/// std::map/unordered_map, so a second binding of the same name in the same
/// context would be ignored — confirm this is intended.
void Checker::addBinding(ByteString Name, Scheme S) {
  auto& Env = Contexts.back()->Env;
  Env.emplace(Name, S);
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Returns the return type registered on the innermost context.
/// Asserts that one has been set (forwardDeclare sets it for block bodies).
Type* Checker::getReturnType() {
  auto& Ctx = *Contexts.back();
  ZEN_ASSERT(Ctx.ReturnType != nullptr);
  return Ctx.ReturnType;
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
static bool hasTypeVar(TVSet& Set, Type* Type) {
|
|
|
|
for (auto TV: Type->getTypeVars()) {
|
|
|
|
if (Set.count(TV)) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
void Checker::addConstraint(Constraint* Constraint) {
|
|
|
|
switch (Constraint->getKind()) {
|
|
|
|
case ConstraintKind::Equal:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<CEqual*>(Constraint);
|
|
|
|
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
|
|
|
|
auto& Ctx = **Iter;
|
|
|
|
if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) {
|
|
|
|
Ctx.Constraints.push_back(Constraint);
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Contexts.front()->Constraints.push_back(Constraint);
|
|
|
|
//auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth);
|
|
|
|
//ZEN_ASSERT(I < Contexts.size());
|
|
|
|
//auto Ctx = Contexts[I];
|
|
|
|
//Ctx->Constraints.push_back(Constraint);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case ConstraintKind::Many:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<CMany*>(Constraint);
|
|
|
|
for (auto Element: Y->Elements) {
|
|
|
|
addConstraint(Element);
|
|
|
|
}
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
case ConstraintKind::Empty:
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// First pass over the tree: creates an inference context for every
/// let-declaration and binds its pattern before any bodies are inferred.
/// Recurses into source files and block bodies only.
void Checker::forwardDeclare(Node* X) {
  switch (X->Type) {

    // Statements that introduce nothing to pre-declare.
    case NodeType::ExpressionStatement:
    case NodeType::ReturnStatement:
    case NodeType::IfStatement:
      break;

    case NodeType::SourceFile:
    {
      for (auto Element: static_cast<SourceFile*>(X)->Elements) {
        forwardDeclare(Element);
      }
      break;
    }

    case NodeType::LetDeclaration:
    {
      auto Let = static_cast<LetDeclaration*>(X);

      // Each declaration gets its own context; infer() re-enters it via Let->Ctx.
      auto Ctx = new InferContext();
      Let->Ctx = Ctx;
      Contexts.push_back(Ctx);

      // The asserted type if present, otherwise a fresh variable owned by Ctx.
      Type* DeclTy = Let->TypeAssert
        ? inferTypeExpression(Let->TypeAssert->TypeExpression)
        : createTypeVar();
      Let->Ty = DeclTy;

      if (Let->Body) {
        switch (Let->Body->Type) {
          case NodeType::LetExprBody:
            break;
          case NodeType::LetBlockBody:
          {
            // Return statements inside the block are checked against this type.
            Ctx->ReturnType = createTypeVar();
            for (auto Element: static_cast<LetBlockBody*>(Let->Body)->Elements) {
              forwardDeclare(Element);
            }
            break;
          }
          default:
            ZEN_UNREACHABLE
        }
      }

      Contexts.pop_back();

      // Bind the pattern in the enclosing context, generalized over Ctx's
      // type variables and constraints.
      inferBindings(Let->Pattern, DeclTy, Ctx->Constraints, Ctx->TVs);
      break;
    }

    default:
      ZEN_UNREACHABLE
  }
}
|
|
|
|
|
|
|
|
/// Second pass: generates typing constraints for statements and declarations.
///
/// Let-declarations re-enter the context created by forwardDeclare(), bind
/// their parameters to fresh type variables, infer their body, and constrain
/// the declared type to `(params) -> return`.
void Checker::infer(Node* X) {

  switch (X->Type) {

    case NodeType::SourceFile:
    {
      auto Y = static_cast<SourceFile*>(X);
      for (auto Element: Y->Elements) {
        infer(Element);
      }
      break;
    }

    case NodeType::IfStatement:
    {
      auto Y = static_cast<IfStatement*>(X);
      for (auto Part: Y->Parts) {
        // Parts with a test (presumably all but a trailing else) must test a Bool.
        if (Part->Test != nullptr) {
          addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test });
        }
        for (auto Element: Part->Elements) {
          infer(Element);
        }
      }
      break;
    }

    case NodeType::LetDeclaration:
    {
      auto Y = static_cast<LetDeclaration*>(X);

      // Re-enter the context forwardDeclare() created for this declaration.
      auto NewCtx = Y->Ctx;
      Contexts.push_back(NewCtx);

      std::vector<Type*> ParamTypes;
      Type* RetType;

      for (auto Param: Y->Params) {
        // TODO incorporate Param->TypeAssert or make it a kind of pattern
        TVar* TV = createTypeVar();
        // The binding's Forall refers to these sets through pointers (see
        // F.TVs / F.Constraints in lookupMono() and instantiate()), so they
        // must outlive this loop iteration. Stack locals would leave every
        // parameter binding dangling once the iteration ends.
        auto NoTVs = new TVSet();
        auto NoConstraints = new ConstraintSet();
        inferBindings(Param->Pattern, TV, *NoConstraints, *NoTVs);
        ParamTypes.push_back(TV);
      }

      if (Y->Body) {
        switch (Y->Body->Type) {
          case NodeType::LetExprBody:
          {
            auto Z = static_cast<LetExprBody*>(Y->Body);
            RetType = inferExpression(Z->Expression);
            break;
          }
          case NodeType::LetBlockBody:
          {
            auto Z = static_cast<LetBlockBody*>(Y->Body);
            // Return statements in the block are constrained against
            // NewCtx->ReturnType (via getReturnType()), which forwardDeclare()
            // created. Using Y->Ty here instead would produce the cyclic
            // constraint Y->Ty ~ (...) -> Y->Ty, which the occurs check
            // always rejects. It would also leave the return statements
            // unconnected to the declaration's type.
            RetType = NewCtx->ReturnType;
            for (auto Element: Z->Elements) {
              infer(Element);
            }
            break;
          }
          default:
            ZEN_UNREACHABLE
        }
      } else {
        RetType = createTypeVar();
      }

      addConstraint(new CEqual { Y->Ty, new TArrow(ParamTypes, RetType), X });

      Contexts.pop_back();

      break;
    }

    case NodeType::ReturnStatement:
    {
      auto Y = static_cast<ReturnStatement*>(X);
      Type* ReturnType;
      if (Y->Expression) {
        ReturnType = inferExpression(Y->Expression);
      } else {
        // A bare `return` yields the empty tuple.
        ReturnType = new TTuple({});
      }
      addConstraint(new CEqual { ReturnType, getReturnType(), X });
      break;
    }

    case NodeType::ExpressionStatement:
    {
      auto Y = static_cast<ExpressionStatement*>(X);
      inferExpression(Y->Expression);
      break;
    }

    default:
      ZEN_UNREACHABLE

  }

}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Allocates a type variable with the next sequential id and registers it
/// in the innermost context's set of owned variables.
TVar* Checker::createTypeVar() {
  auto Fresh = new TVar(nextTypeVarId++);
  auto& Owner = *Contexts.back();
  Owner.TVs.emplace(Fresh);
  return Fresh;
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Instantiates scheme S at a use site.
///
/// Each quantified variable is replaced by a fresh one. The scheme's deferred
/// constraints are re-emitted with those replacements, and Equal constraints
/// are re-attributed to Source. Returns the substituted type.
Type* Checker::instantiate(Scheme& S, Node* Source) {

  switch (S.getKind()) {

    case SchemeKind::Forall:
    {
      auto& F = S.as<Forall>();

      // lookupMono() treats a null TVs pointer as valid, so don't dereference
      // these pointers unconditionally here either.
      //
      // Iterate over copies: createTypeVar() and addConstraint() insert into
      // the innermost context. That can be the very context whose sets F
      // points to, e.g. when a declaration is referenced from inside its own
      // body and lookupCall() did not intercept it. Mutating a set while
      // range-iterating it invalidates the iteration.
      TVSub Sub;
      if (F.TVs != nullptr) {
        auto Quantified = *F.TVs;
        for (auto TV: Quantified) {
          Sub[TV] = createTypeVar();
        }
      }

      if (F.Constraints != nullptr) {
        auto Deferred = *F.Constraints;
        for (auto Constraint: Deferred) {
          auto NewConstraint = Constraint->substitute(Sub);
          // This makes error messages prettier by relating the typing failure
          // to the call site rather than the definition.
          if (NewConstraint->getKind() == ConstraintKind::Equal) {
            static_cast<CEqual *>(NewConstraint)->Source = Source;
          }
          addConstraint(NewConstraint);
        }
      }

      // FIXME substitute should always clone if we set MaxDepth
      auto NewType = F.Type->substitute(Sub);
      //NewType->MaxDepth = std::max(static_cast<unsigned>(Contexts.size()-1), F.Type->MaxDepth);
      return NewType;
    }

  }

}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Converts a syntactic type expression into a Type, recording the result
/// in Mapping. An unknown type name reports BindingNotFoundDiagnostic and
/// yields TAny.
Type* Checker::inferTypeExpression(TypeExpression* X) {

  switch (X->Type) {

    case NodeType::ReferenceTypeExpression:
    {
      auto Ref = static_cast<ReferenceTypeExpression*>(X);
      auto Ident = Ref->Name->Name;
      auto Ty = lookupMono(Ident->Text);
      if (Ty != nullptr) {
        Mapping[X] = Ty;
        return Ty;
      }
      DE.add<BindingNotFoundDiagnostic>(Ident->Text, Ident);
      return new TAny();
    }

    case NodeType::ArrowTypeExpression:
    {
      auto ArrowExpr = static_cast<ArrowTypeExpression*>(X);
      std::vector<Type*> Params;
      for (auto ParamExpr: ArrowExpr->ParamTypes) {
        Params.push_back(inferTypeExpression(ParamExpr));
      }
      auto Ret = inferTypeExpression(ArrowExpr->ReturnType);
      auto Ty = new TArrow(Params, Ret);
      Mapping[X] = Ty;
      return Ty;
    }

    default:
      ZEN_UNREACHABLE

  }

}
|
2022-08-21 16:25:52 +02:00
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Infers the type of an expression, emitting constraints as needed.
///
/// On the normal path the inferred type is recorded in Mapping so that
/// getType() can later report it. Unbound names report
/// BindingNotFoundDiagnostic and yield TAny.
Type* Checker::inferExpression(Expression* X) {

  switch (X->Type) {

    case NodeType::ConstantExpression:
    {
      auto Y = static_cast<ConstantExpression*>(X);
      Type* Ty = nullptr;
      switch (Y->Token->Type) {
        case NodeType::IntegerLiteral:
          Ty = lookupMono("Int");
          break;
        case NodeType::StringLiteral:
          Ty = lookupMono("String");
          break;
        default:
          ZEN_UNREACHABLE
      }
      ZEN_ASSERT(Ty != nullptr);
      Mapping[X] = Ty;
      return Ty;
    }

    case NodeType::ReferenceExpression:
    {
      auto Y = static_cast<ReferenceExpression*>(X);
      ZEN_ASSERT(Y->Name->ModulePath.empty());
      auto Ctx = lookupCall(Y, Y->Name->getSymbolPath());
      if (Ctx) {
        /* std::cerr << "recursive call!\n"; */
        ZEN_ASSERT(Ctx->ReturnType != nullptr);
        return Ctx->ReturnType;
      }
      auto Scm = lookup(Y->Name->Name->Text);
      if (Scm == nullptr) {
        DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
        return new TAny();
      }
      auto Ty = instantiate(*Scm, X);
      Mapping[X] = Ty;
      return Ty;
    }

    case NodeType::CallExpression:
    {
      auto Y = static_cast<CallExpression*>(X);
      auto OpTy = inferExpression(Y->Function);
      auto RetType = createTypeVar();
      std::vector<Type*> ArgTypes;
      for (auto Arg: Y->Args) {
        ArgTypes.push_back(inferExpression(Arg));
      }
      addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X });
      Mapping[X] = RetType;
      return RetType;
    }

    case NodeType::InfixExpression:
    {
      auto Y = static_cast<InfixExpression*>(X);
      auto Scm = lookup(Y->Operator->getText());
      if (Scm == nullptr) {
        DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
        return new TAny();
      }
      auto OpTy = instantiate(*Scm, Y->Operator);
      auto RetTy = createTypeVar();
      std::vector<Type*> ArgTys;
      ArgTys.push_back(inferExpression(Y->LHS));
      ArgTys.push_back(inferExpression(Y->RHS));
      addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X });
      Mapping[X] = RetTy;
      return RetTy;
    }

    case NodeType::NestedExpression:
    {
      auto Y = static_cast<NestedExpression*>(X);
      auto Ty = inferExpression(Y->Inner);
      // Record the parenthesized node itself too, as the other expression
      // kinds do on their normal path. Otherwise getType() returns nullptr
      // for it.
      Mapping[X] = Ty;
      return Ty;
    }

    default:
      ZEN_UNREACHABLE

  }

}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Binds the names in Pattern to Type in the innermost context, generalized
/// over TVs with the deferred Constraints. Only BindPattern is supported.
void Checker::inferBindings(Pattern* Pattern, Type* Type, ConstraintSet& Constraints, TVSet& TVs) {
  if (Pattern->Type == NodeType::BindPattern) {
    auto Bind = static_cast<BindPattern*>(Pattern);
    addBinding(Bind->Name->Text, Forall(TVs, Constraints, Type));
  } else {
    ZEN_UNREACHABLE
  }
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Type-checks a source file and returns the solved substitution.
///
/// Sets up an outermost context with the built-in bindings, runs
/// forwardDeclare() and infer() over SF, then solves the constraints that
/// ended up in the outermost context.
TVSub Checker::check(SourceFile *SF) {
  Contexts.push_back(new InferContext {});

  ConstraintSet NoConstraints;

  // Built-in types and boolean constants.
  addBinding("String", Forall(StringType));
  addBinding("Int", Forall(IntType));
  addBinding("Bool", Forall(BoolType));
  addBinding("True", Forall(BoolType));
  addBinding("False", Forall(BoolType));

  // Equality is polymorphic: forall a. (a, a) -> Bool.
  auto A = createTypeVar();
  TVSet SingleA { A };
  addBinding("==", Forall(SingleA, NoConstraints, new TArrow({ A, A }, BoolType)));

  // Arithmetic operators all have type (Int, Int) -> Int.
  for (auto Op: { "+", "-", "*", "/" }) {
    addBinding(Op, Forall(new TArrow({ IntType, IntType }, IntType)));
  }

  forwardDeclare(SF);
  infer(SF);

  TVSub Solution;
  solve(new CMany(Contexts.front()->Constraints), Solution);
  Contexts.pop_back();
  return Solution;
}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
void Checker::solve(Constraint* Constraint, TVSub& Solution) {
|
2022-08-21 16:25:52 +02:00
|
|
|
|
|
|
|
std::stack<class Constraint*> Queue;
|
2022-08-21 20:56:58 +02:00
|
|
|
Queue.push(Constraint);
|
2022-08-21 16:25:52 +02:00
|
|
|
|
|
|
|
while (!Queue.empty()) {
|
|
|
|
|
|
|
|
auto Constraint = Queue.top();
|
|
|
|
|
|
|
|
Queue.pop();
|
|
|
|
|
|
|
|
switch (Constraint->getKind()) {
|
|
|
|
|
|
|
|
case ConstraintKind::Empty:
|
|
|
|
break;
|
|
|
|
|
|
|
|
case ConstraintKind::Many:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<CMany*>(Constraint);
|
2022-08-26 22:10:18 +02:00
|
|
|
for (auto Constraint: Y->Elements) {
|
2022-08-21 16:25:52 +02:00
|
|
|
Queue.push(Constraint);
|
|
|
|
}
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
case ConstraintKind::Equal:
|
|
|
|
{
|
|
|
|
auto Y = static_cast<CEqual*>(Constraint);
|
2023-04-12 11:15:36 +02:00
|
|
|
/* std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl; */
|
2022-08-25 19:04:25 +02:00
|
|
|
if (!unify(Y->Left, Y->Right, Solution)) {
|
|
|
|
DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source);
|
2022-08-21 16:25:52 +02:00
|
|
|
}
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Attempts to make A and B equal, extending Solution with new variable
/// bindings. Returns false if the types cannot be unified. Solution may
/// already contain some new bindings when that happens.
bool Checker::unify(Type* A, Type* B, TVSub& Solution) {

  // Follow variable bindings until reaching an unbound variable or a
  // non-variable type. The whole chain must be followed: stopping after one
  // step could leave us holding an already-bound variable, whose binding
  // would then be overwritten below.
  auto Resolve = [&Solution](Type* Ty) -> Type* {
    while (Ty->getKind() == TypeKind::Var) {
      auto Match = Solution.find(static_cast<TVar*>(Ty));
      if (Match == Solution.end()) {
        break;
      }
      Ty = Match->second;
    }
    return Ty;
  };

  A = Resolve(A);
  B = Resolve(B);

  if (A->getKind() == TypeKind::Var) {
    auto Y = static_cast<TVar*>(A);
    // A variable trivially unifies with itself. Without this check the
    // occurs check below would reject it.
    if (B->getKind() == TypeKind::Var && static_cast<TVar*>(B)->Id == Y->Id) {
      return true;
    }
    // Occurs check, performed with the current bindings applied, so that a
    // variable reachable only through an existing binding is detected too.
    // This keeps Solution acyclic.
    if (B->substitute(Solution)->hasTypeVar(Y)) {
      return false;
    }
    Solution[Y] = B;
    return true;
  }

  if (B->getKind() == TypeKind::Var) {
    return unify(B, A, Solution);
  }

  if (A->getKind() == TypeKind::Any || B->getKind() == TypeKind::Any) {
    return true;
  }

  if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) {
    auto Y = static_cast<TArrow*>(A);
    auto Z = static_cast<TArrow*>(B);
    if (Y->ParamTypes.size() != Z->ParamTypes.size()) {
      return false;
    }
    auto Count = Y->ParamTypes.size();
    for (std::size_t I = 0; I < Count; I++) {
      if (!unify(Y->ParamTypes[I], Z->ParamTypes[I], Solution)) {
        return false;
      }
    }
    return unify(Y->ReturnType, Z->ReturnType, Solution);
  }

  // An arrow without parameters unifies with its return type.
  if (A->getKind() == TypeKind::Arrow) {
    auto Y = static_cast<TArrow*>(A);
    if (Y->ParamTypes.empty()) {
      return unify(Y->ReturnType, B, Solution);
    }
  }

  if (B->getKind() == TypeKind::Arrow) {
    return unify(B, A, Solution);
  }

  if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) {
    auto Y = static_cast<TTuple*>(A);
    auto Z = static_cast<TTuple*>(B);
    if (Y->ElementTypes.size() != Z->ElementTypes.size()) {
      return false;
    }
    auto Count = Y->ElementTypes.size();
    bool Success = true;
    for (size_t I = 0; I < Count; I++) {
      if (!unify(Y->ElementTypes[I], Z->ElementTypes[I], Solution)) {
        Success = false;
      }
    }
    return Success;
  }

  if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) {
    auto Y = static_cast<TCon*>(A);
    auto Z = static_cast<TCon*>(B);
    if (Y->Id != Z->Id) {
      return false;
    }
    ZEN_ASSERT(Y->Args.size() == Z->Args.size());
    auto Count = Y->Args.size();
    for (std::size_t I = 0; I < Count; I++) {
      if (!unify(Y->Args[I], Z->Args[I], Solution)) {
        return false;
      }
    }
    return true;
  }

  return false;

}
|
|
|
|
|
2022-08-26 22:10:18 +02:00
|
|
|
/// Resolves Path from Source's scope and returns the inference context
/// registered for that definition in CallGraph, or nullptr if none.
InferContext* Checker::lookupCall(Node* Source, SymbolPath Path) {
  auto Def = Source->getScope()->lookup(Path);
  auto Match = CallGraph.find(Def);
  if (Match != CallGraph.end()) {
    return Match->second;
  }
  return nullptr;
}
|
|
|
|
|
|
|
|
/// Returns the type recorded for Node during inference with Solution
/// applied, or nullptr if no type was recorded for it.
Type* Checker::getType(Node *Node, const TVSub &Solution) {
  auto Match = Mapping.find(Node);
  if (Match != Mapping.end()) {
    return Match->second->substitute(Solution);
  }
  return nullptr;
}
|
|
|
|
|
2022-08-21 16:25:52 +02:00
|
|
|
}
|
|
|
|
|