Add partial support for recursive functions

This commit is contained in:
Sam Vervaeck 2022-08-26 22:10:18 +02:00
parent 43301a3a44
commit cfb596f8e1
6 changed files with 499 additions and 132 deletions

View file

@ -61,6 +61,7 @@ if (BOLT_ENABLE_TESTS)
add_executable( add_executable(
alltests alltests
src/TestText.cc src/TestText.cc
src/TestChecker.cc
) )
target_link_libraries( target_link_libraries(
alltests alltests

View file

@ -1,7 +1,9 @@
#ifndef BOLT_CST_HPP #ifndef BOLT_CST_HPP
#define BOLT_CST_HPP #define BOLT_CST_HPP
#include <istream>
#include <iterator> #include <iterator>
#include <unordered_map>
#include <vector> #include <vector>
#include "bolt/Text.hpp" #include "bolt/Text.hpp"
@ -10,6 +12,10 @@
namespace bolt { namespace bolt {
class Token;
class SourceFile;
class Scope;
enum class NodeType { enum class NodeType {
Equals, Equals,
Colon, Colon,
@ -47,6 +53,7 @@ namespace bolt {
ArrowTypeExpression, ArrowTypeExpression,
BindPattern, BindPattern,
ReferenceExpression, ReferenceExpression,
NestedExpression,
ConstantExpression, ConstantExpression,
CallExpression, CallExpression,
InfixExpression, InfixExpression,
@ -65,8 +72,10 @@ namespace bolt {
SourceFile, SourceFile,
}; };
class Token; struct SymbolPath {
class SourceFile; std::vector<ByteString> Modules;
ByteString Name;
};
class Node { class Node {
@ -101,10 +110,28 @@ namespace bolt {
SourceFile* getSourceFile(); SourceFile* getSourceFile();
virtual Scope* getScope();
virtual ~Node(); virtual ~Node();
}; };
class Scope {
Node* Source;
std::unordered_map<ByteString, Node*> Mapping;
public:
inline Scope(Node* Source):
Source(Source) {}
Node* lookup(SymbolPath Path);
Scope* getParentScope();
};
class Token : public Node { class Token : public Node {
TextLoc StartLoc; TextLoc StartLoc;
@ -551,6 +578,8 @@ namespace bolt {
void setParents() override; void setParents() override;
SymbolPath getSymbolPath() const;
~QualifiedName(); ~QualifiedName();
}; };
@ -661,6 +690,31 @@ namespace bolt {
}; };
class NestedExpression : public Expression {
public:
LParen* LParen;
Expression* Inner;
RParen* RParen;
inline NestedExpression(
class LParen* LParen,
Expression* Inner,
class RParen* RParen
): Expression(NodeType::NestedExpression),
LParen(LParen),
Inner(Inner),
RParen(RParen) {}
void setParents() override;
Token* getFirstToken() override;
Token* getLastToken() override;
~NestedExpression();
};
class ConstantExpression : public Expression { class ConstantExpression : public Expression {
public: public:
@ -931,9 +985,18 @@ namespace bolt {
}; };
class Type;
class InferContext;
class LetDeclaration : public Node { class LetDeclaration : public Node {
Scope TheScope;
public: public:
InferContext* Ctx;
class Type* Ty;
PubKeyword* PubKeyword; PubKeyword* PubKeyword;
LetKeyword* LetKeyword; LetKeyword* LetKeyword;
MutKeyword* MutKeyword; MutKeyword* MutKeyword;
@ -951,6 +1014,7 @@ namespace bolt {
class TypeAssert* TypeAssert, class TypeAssert* TypeAssert,
LetBody* Body LetBody* Body
): Node(NodeType::LetDeclaration), ): Node(NodeType::LetDeclaration),
TheScope(this),
PubKeyword(PubKeyword), PubKeyword(PubKeyword),
LetKeyword(LetKeywod), LetKeyword(LetKeywod),
MutKeyword(MutKeyword), MutKeyword(MutKeyword),
@ -959,6 +1023,10 @@ namespace bolt {
TypeAssert(TypeAssert), TypeAssert(TypeAssert),
Body(Body) {} Body(Body) {}
inline Scope* getScope() override {
return &TheScope;
}
void setParents() override; void setParents() override;
Token* getFirstToken() override; Token* getFirstToken() override;
@ -1025,6 +1093,9 @@ namespace bolt {
}; };
class SourceFile : public Node { class SourceFile : public Node {
Scope TheScope;
public: public:
TextFile& File; TextFile& File;
@ -1032,7 +1103,7 @@ namespace bolt {
std::vector<Node*> Elements; std::vector<Node*> Elements;
SourceFile(TextFile& File, std::vector<Node*> Elements): SourceFile(TextFile& File, std::vector<Node*> Elements):
Node(NodeType::SourceFile), File(File), Elements(Elements) {} Node(NodeType::SourceFile), TheScope(this), File(File), Elements(Elements) {}
inline TextFile& getTextFile() { inline TextFile& getTextFile() {
return File; return File;
@ -1043,6 +1114,10 @@ namespace bolt {
Token* getFirstToken() override; Token* getFirstToken() override;
Token* getLastToken() override; Token* getLastToken() override;
inline Scope* getScope() override {
return &TheScope;
}
~SourceFile(); ~SourceFile();
}; };

View file

@ -4,8 +4,9 @@
#include "zen/config.hpp" #include "zen/config.hpp"
#include "bolt/ByteString.hpp" #include "bolt/ByteString.hpp"
#include "bolt/CST.hpp"
#include <stack> #include <istream>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -15,10 +16,6 @@ namespace bolt {
class DiagnosticEngine; class DiagnosticEngine;
class Node; class Node;
class Expression;
class TypeExpression;
class Pattern;
class SourceFile;
class Type; class Type;
class TVar; class TVar;
@ -47,6 +44,14 @@ namespace bolt {
bool hasTypeVar(const TVar* TV); bool hasTypeVar(const TVar* TV);
void addTypeVars(TVSet& TVs);
inline TVSet getTypeVars() {
TVSet Out;
addTypeVars(Out);
return Out;
}
Type* substitute(const TVSub& Sub); Type* substitute(const TVSub& Sub);
inline TypeKind getKind() const noexcept { inline TypeKind getKind() const noexcept {
@ -273,14 +278,6 @@ namespace bolt {
inline InferContext(InferContext* Parent = nullptr): inline InferContext(InferContext* Parent = nullptr):
Parent(Parent), ReturnType(nullptr) {} Parent(Parent), ReturnType(nullptr) {}
void addConstraint(Constraint* C);
void addBinding(ByteString Name, Scheme Scm);
Type* lookupMono(ByteString Name);
Scheme* lookup(ByteString Name);
}; };
class Checker { class Checker {
@ -290,36 +287,66 @@ namespace bolt {
size_t nextConTypeId = 0; size_t nextConTypeId = 0;
size_t nextTypeVarId = 0; size_t nextTypeVarId = 0;
std::unordered_map<Node*, Type*> Mapping;
std::unordered_map<Node*, InferContext*> CallGraph;
Type* BoolType; Type* BoolType;
Type* IntType; Type* IntType;
Type* StringType; Type* StringType;
std::stack<InferContext> Contexts; std::vector<InferContext*> Contexts;
void addConstraint(Constraint* Constraint); void addConstraint(Constraint* Constraint);
Type* inferExpression(Expression* Expression, InferContext& Ctx); void forwardDeclare(Node* Node);
Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx);
void inferBindings(Pattern* Pattern, Type* T, InferContext& Ctx, ConstraintSet& Constraints, TVSet& Tvs); Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE);
void infer(Node* node, InferContext& Ctx); void inferBindings(Pattern* Pattern, Type* T, ConstraintSet& Constraints, TVSet& Tvs);
void infer(Node* node);
TCon* createPrimConType(); TCon* createPrimConType();
TVar* createTypeVar(InferContext& Ctx); TVar* createTypeVar();
Type* instantiate(Scheme& S, InferContext& Ctx, Node* Source); void addBinding(ByteString Name, Scheme Scm);
Type* lookupMono(ByteString Name);
InferContext* lookupCall(Node* Source, SymbolPath Path);
Type* getReturnType();
Scheme* lookup(ByteString Name);
Type* instantiate(Scheme& S, Node* Source);
bool unify(Type* A, Type* B, TVSub& Solution); bool unify(Type* A, Type* B, TVSub& Solution);
void solve(Constraint* Constraint); void solve(Constraint* Constraint, TVSub& Solution);
public: public:
Checker(DiagnosticEngine& DE); Checker(DiagnosticEngine& DE);
void check(SourceFile* SF); TVSub check(SourceFile* SF);
inline Type* getBoolType() {
return BoolType;
}
inline Type* getStringType() {
return StringType;
}
inline Type* getIntType() {
return IntType;
}
Type* getType(Node* Node, const TVSub& Solution);
}; };

View file

@ -5,6 +5,25 @@
namespace bolt { namespace bolt {
Node* Scope::lookup(SymbolPath Path) {
auto Curr = this;
do {
auto Match = Curr->Mapping.find(Path.Name);
if (Match != Curr->Mapping.end()) {
return Match->second;
}
Curr = Curr->getParentScope();
} while (Curr != nullptr);
return nullptr;
}
Scope* Scope::getParentScope() {
if (Source->Parent == nullptr) {
return nullptr;
}
return Source->Parent->getScope();
}
SourceFile* Node::getSourceFile() { SourceFile* Node::getSourceFile() {
auto CurrNode = this; auto CurrNode = this;
for (;;) { for (;;) {
@ -23,6 +42,10 @@ namespace bolt {
}; };
} }
Scope* Node::getScope() {
return this->Parent->getScope();
}
TextLoc Token::getEndLoc() { TextLoc Token::getEndLoc() {
auto EndLoc = StartLoc; auto EndLoc = StartLoc;
EndLoc.advance(getText()); EndLoc.advance(getText());
@ -61,6 +84,13 @@ namespace bolt {
Name->Parent = this; Name->Parent = this;
} }
void NestedExpression::setParents() {
LParen->Parent = this;
Inner->Parent = this;
Inner->setParents();
RParen->Parent = this;
}
void ConstantExpression::setParents() { void ConstantExpression::setParents() {
Token->Parent = this; Token->Parent = this;
} }
@ -330,6 +360,12 @@ namespace bolt {
Name->unref(); Name->unref();
} }
NestedExpression::~NestedExpression() {
LParen->unref();
Inner->unref();
RParen->unref();
}
ConstantExpression::~ConstantExpression() { ConstantExpression::~ConstantExpression() {
Token->unref(); Token->unref();
} }
@ -493,6 +529,14 @@ namespace bolt {
return Name->getLastToken(); return Name->getLastToken();
} }
Token* NestedExpression::getFirstToken() {
return LParen;
}
Token* NestedExpression::getLastToken() {
return RParen;
}
Token* ConstantExpression::getFirstToken() { Token* ConstantExpression::getFirstToken() {
return Token; return Token;
} }
@ -786,5 +830,13 @@ namespace bolt {
return ".."; return "..";
} }
SymbolPath QualifiedName::getSymbolPath() const {
std::vector<ByteString> ModuleNames;
for (auto Ident: ModulePath) {
ModuleNames.push_back(Ident->Text);
}
return SymbolPath { ModuleNames, Name->Text };
}
} }

View file

@ -11,6 +11,41 @@ namespace bolt {
std::string describe(const Type* Ty); std::string describe(const Type* Ty);
void Type::addTypeVars(TVSet& TVs) {
switch (Kind) {
case TypeKind::Var:
TVs.emplace(static_cast<TVar*>(this));
break;
case TypeKind::Arrow:
{
auto Y = static_cast<TArrow*>(this);
for (auto Ty: Y->ParamTypes) {
Ty->addTypeVars(TVs);
}
Y->ReturnType->addTypeVars(TVs);
break;
}
case TypeKind::Con:
{
auto Y = static_cast<TCon*>(this);
for (auto Ty: Y->Args) {
Ty->addTypeVars(TVs);
}
break;
}
case TypeKind::Tuple:
{
auto Y = static_cast<TTuple*>(this);
for (auto Ty: Y->ElementTypes) {
Ty->addTypeVars(TVs);
}
break;
}
case TypeKind::Any:
break;
}
}
bool Type::hasTypeVar(const TVar* TV) { bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) { switch (Kind) {
case TypeKind::Var: case TypeKind::Var:
@ -61,32 +96,50 @@ namespace bolt {
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Y = static_cast<TArrow*>(this); auto Y = static_cast<TArrow*>(this);
bool Changed = false;
std::vector<Type*> NewParamTypes; std::vector<Type*> NewParamTypes;
for (auto Ty: Y->ParamTypes) { for (auto Ty: Y->ParamTypes) {
NewParamTypes.push_back(Ty->substitute(Sub)); auto NewParamType = Ty->substitute(Sub);
if (NewParamType != Ty) {
Changed = true;
}
NewParamTypes.push_back(NewParamType);
} }
auto NewRetTy = Y->ReturnType->substitute(Sub) ; auto NewRetTy = Y->ReturnType->substitute(Sub) ;
return new TArrow(NewParamTypes, NewRetTy); if (NewRetTy != Y->ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
} }
case TypeKind::Any: case TypeKind::Any:
return this; return this;
case TypeKind::Con: case TypeKind::Con:
{ {
auto Y = static_cast<TCon*>(this); auto Y = static_cast<TCon*>(this);
bool Changed = false;
std::vector<Type*> NewArgs; std::vector<Type*> NewArgs;
for (auto Arg: Y->Args) { for (auto Arg: Y->Args) {
NewArgs.push_back(Arg->substitute(Sub)); auto NewArg = Arg->substitute(Sub);
if (NewArg != Arg) {
Changed = true;
} }
return new TCon(Y->Id, NewArgs, Y->DisplayName); NewArgs.push_back(NewArg);
}
return Changed ? new TCon(Y->Id, NewArgs, Y->DisplayName) : this;
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Y = static_cast<TTuple*>(this); auto Y = static_cast<TTuple*>(this);
bool Changed = false;
std::vector<Type*> NewElementTypes; std::vector<Type*> NewElementTypes;
for (auto Ty: Y->ElementTypes) { for (auto Ty: Y->ElementTypes) {
NewElementTypes.push_back(Ty->substitute(Sub)); auto NewElementType = Ty->substitute(Sub);
if (NewElementType != Ty) {
Changed = true;
} }
return new TTuple(NewElementTypes); NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : this;
} }
} }
} }
@ -102,7 +155,7 @@ namespace bolt {
{ {
auto Y = static_cast<CMany*>(this); auto Y = static_cast<CMany*>(this);
auto NewConstraints = new ConstraintSet(); auto NewConstraints = new ConstraintSet();
for (auto Element: Y->Constraints) { for (auto Element: Y->Elements) {
NewConstraints->push_back(Element->substitute(Sub)); NewConstraints->push_back(Element->substitute(Sub));
} }
return new CMany(*NewConstraints); return new CMany(*NewConstraints);
@ -112,21 +165,25 @@ namespace bolt {
} }
} }
Scheme* InferContext::lookup(ByteString Name) { Checker::Checker(DiagnosticEngine& DE):
InferContext* Curr = this; DE(DE) {
for (;;) { BoolType = new TCon(nextConTypeId++, {}, "Bool");
IntType = new TCon(nextConTypeId++, {}, "Int");
StringType = new TCon(nextConTypeId++, {}, "String");
}
Scheme* Checker::lookup(ByteString Name) {
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
auto Curr = *Iter;
auto Match = Curr->Env.find(Name); auto Match = Curr->Env.find(Name);
if (Match != Curr->Env.end()) { if (Match != Curr->Env.end()) {
return &Match->second; return &Match->second;
} }
Curr = Curr->Parent; }
if (Curr == nullptr) {
return nullptr; return nullptr;
} }
}
}
Type* InferContext::lookupMono(ByteString Name) { Type* Checker::lookupMono(ByteString Name) {
auto Scm = lookup(Name); auto Scm = lookup(Name);
if (Scm == nullptr) { if (Scm == nullptr) {
return nullptr; return nullptr;
@ -136,22 +193,126 @@ namespace bolt {
return F.Type; return F.Type;
} }
void InferContext::addBinding(ByteString Name, Scheme S) { void Checker::addBinding(ByteString Name, Scheme S) {
Env.emplace(Name, S); Contexts.back()->Env.emplace(Name, S);
} }
void InferContext::addConstraint(Constraint *C) { Type* Checker::getReturnType() {
Constraints.push_back(C); auto Ty = Contexts.back()->ReturnType;
ZEN_ASSERT(Ty != nullptr);
return Ty;
} }
Checker::Checker(DiagnosticEngine& DE): static bool hasTypeVar(TVSet& Set, Type* Type) {
DE(DE) { for (auto TV: Type->getTypeVars()) {
BoolType = new TCon(nextConTypeId++, {}, "Bool"); if (Set.count(TV)) {
IntType = new TCon(nextConTypeId++, {}, "Int"); return true;
StringType = new TCon(nextConTypeId++, {}, "String"); }
}
return false;
} }
void Checker::infer(Node* X, InferContext& Ctx) { void Checker::addConstraint(Constraint* Constraint) {
switch (Constraint->getKind()) {
case ConstraintKind::Equal:
{
auto Y = static_cast<CEqual*>(Constraint);
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
auto& Ctx = **Iter;
if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) {
Ctx.Constraints.push_back(Constraint);
return;
}
}
Contexts.front()->Constraints.push_back(Constraint);
//auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth);
//ZEN_ASSERT(I < Contexts.size());
//auto Ctx = Contexts[I];
//Ctx->Constraints.push_back(Constraint);
break;
}
case ConstraintKind::Many:
{
auto Y = static_cast<CMany*>(Constraint);
for (auto Element: Y->Elements) {
addConstraint(Element);
}
break;
}
case ConstraintKind::Empty:
break;
}
}
void Checker::forwardDeclare(Node* X) {
switch (X->Type) {
case NodeType::ExpressionStatement:
case NodeType::ReturnStatement:
case NodeType::IfStatement:
break;
case NodeType::SourceFile:
{
auto Y = static_cast<SourceFile*>(X);
for (auto Element: Y->Elements) {
forwardDeclare(Element) ;
}
break;
}
case NodeType::LetDeclaration:
{
auto Y = static_cast<LetDeclaration*>(X);
auto NewCtx = new InferContext();
Y->Ctx = NewCtx;
std::cerr << Y << std::endl;
Contexts.push_back(NewCtx);
Type* Ty;
if (Y->TypeAssert) {
Ty = inferTypeExpression(Y->TypeAssert->TypeExpression);
} else {
Ty = createTypeVar();
}
Y->Ty = Ty;
if (Y->Body) {
switch (Y->Body->Type) {
case NodeType::LetExprBody:
break;
case NodeType::LetBlockBody:
{
auto Z = static_cast<LetBlockBody*>(Y->Body);
for (auto Element: Z->Elements) {
forwardDeclare(Element);
}
break;
}
default:
ZEN_UNREACHABLE
}
}
Contexts.pop_back();
inferBindings(Y->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs);
break;
}
default:
ZEN_UNREACHABLE
}
}
void Checker::infer(Node* X) {
switch (X->Type) { switch (X->Type) {
@ -159,7 +320,7 @@ namespace bolt {
{ {
auto Y = static_cast<SourceFile*>(X); auto Y = static_cast<SourceFile*>(X);
for (auto Element: Y->Elements) { for (auto Element: Y->Elements) {
infer(Element, Ctx); infer(Element);
} }
break; break;
} }
@ -169,10 +330,10 @@ namespace bolt {
auto Y = static_cast<IfStatement*>(X); auto Y = static_cast<IfStatement*>(X);
for (auto Part: Y->Parts) { for (auto Part: Y->Parts) {
if (Part->Test != nullptr) { if (Part->Test != nullptr) {
Ctx.addConstraint(new CEqual { BoolType, inferExpression(Part->Test, Ctx), Part->Test }); addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test });
} }
for (auto Element: Part->Elements) { for (auto Element: Part->Elements) {
infer(Element, Ctx); infer(Element);
} }
} }
break; break;
@ -182,24 +343,18 @@ namespace bolt {
{ {
auto Y = static_cast<LetDeclaration*>(X); auto Y = static_cast<LetDeclaration*>(X);
auto NewCtx = new InferContext { Ctx }; auto NewCtx = Y->Ctx;
Contexts.push_back(NewCtx);
Type* Ty;
if (Y->TypeAssert) {
Ty = inferTypeExpression(Y->TypeAssert->TypeExpression, *NewCtx);
} else {
Ty = createTypeVar(*NewCtx);
}
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
Type* RetType; Type* RetType;
for (auto Param: Y->Params) { for (auto Param: Y->Params) {
// TODO incorporate Param->TypeAssert or make it a kind of pattern // TODO incorporate Param->TypeAssert or make it a kind of pattern
TVar* TV = createTypeVar(*NewCtx); TVar* TV = createTypeVar();
TVSet NoTVs; TVSet NoTVs;
ConstraintSet NoConstraints; ConstraintSet NoConstraints;
inferBindings(Param->Pattern, TV, *NewCtx, NoConstraints, NoTVs); inferBindings(Param->Pattern, TV, NoConstraints, NoTVs);
ParamTypes.push_back(TV); ParamTypes.push_back(TV);
} }
@ -208,16 +363,16 @@ namespace bolt {
case NodeType::LetExprBody: case NodeType::LetExprBody:
{ {
auto Z = static_cast<LetExprBody*>(Y->Body); auto Z = static_cast<LetExprBody*>(Y->Body);
RetType = inferExpression(Z->Expression, *NewCtx); RetType = inferExpression(Z->Expression);
break; break;
} }
case NodeType::LetBlockBody: case NodeType::LetBlockBody:
{ {
auto Z = static_cast<LetBlockBody*>(Y->Body); auto Z = static_cast<LetBlockBody*>(Y->Body);
RetType = createTypeVar(*NewCtx); RetType = createTypeVar();
NewCtx->ReturnType = RetType; NewCtx->ReturnType = RetType;
for (auto Element: Z->Elements) { for (auto Element: Z->Elements) {
infer(Element, *NewCtx); infer(Element);
} }
break; break;
} }
@ -225,12 +380,12 @@ namespace bolt {
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
} else { } else {
RetType = createTypeVar(*NewCtx); RetType = createTypeVar();
} }
NewCtx->addConstraint(new CEqual { Ty, new TArrow(ParamTypes, RetType), X }); addConstraint(new CEqual { Y->Ty, new TArrow(ParamTypes, RetType), X });
inferBindings(Y->Pattern, Ty, Ctx, NewCtx->Constraints, NewCtx->TVs); Contexts.pop_back();
break; break;
} }
@ -240,19 +395,18 @@ namespace bolt {
auto Y = static_cast<ReturnStatement*>(X); auto Y = static_cast<ReturnStatement*>(X);
Type* ReturnType; Type* ReturnType;
if (Y->Expression) { if (Y->Expression) {
ReturnType = inferExpression(Y->Expression, Ctx); ReturnType = inferExpression(Y->Expression);
} else { } else {
ReturnType = new TTuple({}); ReturnType = new TTuple({});
} }
ZEN_ASSERT(Ctx.ReturnType != nullptr); addConstraint(new CEqual { ReturnType, getReturnType(), X });
Ctx.addConstraint(new CEqual { ReturnType, Ctx.ReturnType, X });
break; break;
} }
case NodeType::ExpressionStatement: case NodeType::ExpressionStatement:
{ {
auto Y = static_cast<ExpressionStatement*>(X); auto Y = static_cast<ExpressionStatement*>(X);
inferExpression(Y->Expression, Ctx); inferExpression(Y->Expression);
break; break;
} }
@ -263,13 +417,13 @@ namespace bolt {
} }
TVar* Checker::createTypeVar(InferContext& Ctx) { TVar* Checker::createTypeVar() {
auto TV = new TVar(nextTypeVarId++); auto TV = new TVar(nextTypeVarId++);
Ctx.TVs.emplace(TV); Contexts.back()->TVs.emplace(TV);
return TV; return TV;
} }
Type* Checker::instantiate(Scheme& S, InferContext& Ctx, Node* Source) { Type* Checker::instantiate(Scheme& S, Node* Source) {
switch (S.getKind()) { switch (S.getKind()) {
@ -278,13 +432,9 @@ namespace bolt {
auto& F = S.as<Forall>(); auto& F = S.as<Forall>();
TVSub Sub; TVSub Sub;
if (F.TVs) {
for (auto TV: *F.TVs) { for (auto TV: *F.TVs) {
Sub[TV] = createTypeVar(Ctx); Sub[TV] = createTypeVar();
} }
}
if (F.Constraints) {
for (auto Constraint: *F.Constraints) { for (auto Constraint: *F.Constraints) {
@ -296,29 +446,32 @@ namespace bolt {
static_cast<CEqual *>(NewConstraint)->Source = Source; static_cast<CEqual *>(NewConstraint)->Source = Source;
} }
Ctx.addConstraint(NewConstraint); addConstraint(NewConstraint);
}
} }
return F.Type->substitute(Sub); // FIXME substitute should always clone if we set MaxDepth
auto NewType = F.Type->substitute(Sub);
//NewType->MaxDepth = std::max(static_cast<unsigned>(Contexts.size()-1), F.Type->MaxDepth);
return NewType;
} }
} }
} }
Type* Checker::inferTypeExpression(TypeExpression* X, InferContext& Ctx) { Type* Checker::inferTypeExpression(TypeExpression* X) {
switch (X->Type) { switch (X->Type) {
case NodeType::ReferenceTypeExpression: case NodeType::ReferenceTypeExpression:
{ {
auto Y = static_cast<ReferenceTypeExpression*>(X); auto Y = static_cast<ReferenceTypeExpression*>(X);
auto Ty = Ctx.lookupMono(Y->Name->Name->Text); auto Ty = lookupMono(Y->Name->Name->Text);
if (Ty == nullptr) { if (Ty == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name->Name); DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name->Name);
return new TAny(); return new TAny();
} }
Mapping[X] = Ty;
return Ty; return Ty;
} }
@ -327,10 +480,12 @@ namespace bolt {
auto Y = static_cast<ArrowTypeExpression*>(X); auto Y = static_cast<ArrowTypeExpression*>(X);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto ParamType: Y->ParamTypes) { for (auto ParamType: Y->ParamTypes) {
ParamTypes.push_back(inferTypeExpression(ParamType, Ctx)); ParamTypes.push_back(inferTypeExpression(ParamType));
} }
auto ReturnType = inferTypeExpression(Y->ReturnType, Ctx); auto ReturnType = inferTypeExpression(Y->ReturnType);
return new TArrow(ParamTypes, ReturnType); auto Ty = new TArrow(ParamTypes, ReturnType);
Mapping[X] = Ty;
return Ty;
} }
default: default:
@ -339,7 +494,7 @@ namespace bolt {
} }
} }
Type* Checker::inferExpression(Expression* X, InferContext& Ctx) { Type* Checker::inferExpression(Expression* X) {
switch (X->Type) { switch (X->Type) {
@ -349,15 +504,16 @@ namespace bolt {
Type* Ty = nullptr; Type* Ty = nullptr;
switch (Y->Token->Type) { switch (Y->Token->Type) {
case NodeType::IntegerLiteral: case NodeType::IntegerLiteral:
Ty = Ctx.lookupMono("Int"); Ty = lookupMono("Int");
break; break;
case NodeType::StringLiteral: case NodeType::StringLiteral:
Ty = Ctx.lookupMono("String"); Ty = lookupMono("String");
break; break;
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
ZEN_ASSERT(Ty != nullptr); ZEN_ASSERT(Ty != nullptr);
Mapping[X] = Ty;
return Ty; return Ty;
} }
@ -365,44 +521,58 @@ namespace bolt {
{ {
auto Y = static_cast<ReferenceExpression*>(X); auto Y = static_cast<ReferenceExpression*>(X);
ZEN_ASSERT(Y->Name->ModulePath.empty()); ZEN_ASSERT(Y->Name->ModulePath.empty());
auto Scm = Ctx.lookup(Y->Name->Name->Text); auto Ctx = lookupCall(Y, Y->Name->getSymbolPath());
if (Ctx) {
return Ctx->ReturnType;
}
auto Scm = lookup(Y->Name->Name->Text);
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name); DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
return new TAny(); return new TAny();
} }
return instantiate(*Scm, Ctx, X); auto Ty = instantiate(*Scm, X);
Mapping[X] = Ty;
return Ty;
} }
case NodeType::CallExpression: case NodeType::CallExpression:
{ {
auto Y = static_cast<CallExpression*>(X); auto Y = static_cast<CallExpression*>(X);
auto OpTy = inferExpression(Y->Function, Ctx); auto OpTy = inferExpression(Y->Function);
auto RetType = createTypeVar(Ctx); auto RetType = createTypeVar();
std::vector<Type*> ArgTypes; std::vector<Type*> ArgTypes;
for (auto Arg: Y->Args) { for (auto Arg: Y->Args) {
ArgTypes.push_back(inferExpression(Arg, Ctx)); ArgTypes.push_back(inferExpression(Arg));
} }
Ctx.addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X });
Mapping[X] = RetType;
return RetType; return RetType;
} }
case NodeType::InfixExpression: case NodeType::InfixExpression:
{ {
auto Y = static_cast<InfixExpression*>(X); auto Y = static_cast<InfixExpression*>(X);
auto Scm = Ctx.lookup(Y->Operator->getText()); auto Scm = lookup(Y->Operator->getText());
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator); DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
return new TAny(); return new TAny();
} }
auto OpTy = instantiate(*Scm, Ctx, Y->Operator); auto OpTy = instantiate(*Scm, Y->Operator);
auto RetTy = createTypeVar(Ctx); auto RetTy = createTypeVar();
std::vector<Type*> ArgTys; std::vector<Type*> ArgTys;
ArgTys.push_back(inferExpression(Y->LHS, Ctx)); ArgTys.push_back(inferExpression(Y->LHS));
ArgTys.push_back(inferExpression(Y->RHS, Ctx)); ArgTys.push_back(inferExpression(Y->RHS));
Ctx.addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X });
Mapping[X] = RetTy;
return RetTy; return RetTy;
} }
case NodeType::NestedExpression:
{
auto Y = static_cast<NestedExpression*>(X);
return inferExpression(Y->Inner);
}
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
@ -410,12 +580,12 @@ namespace bolt {
} }
void Checker::inferBindings(Pattern* Pattern, Type* Type, InferContext& Ctx, ConstraintSet& Constraints, TVSet& TVs) { void Checker::inferBindings(Pattern* Pattern, Type* Type, ConstraintSet& Constraints, TVSet& TVs) {
switch (Pattern->Type) { switch (Pattern->Type) {
case NodeType::BindPattern: case NodeType::BindPattern:
Ctx.addBinding(static_cast<BindPattern*>(Pattern)->Name->Text, Forall(TVs, Constraints, Type)); addBinding(static_cast<BindPattern*>(Pattern)->Name->Text, Forall(TVs, Constraints, Type));
break; break;
default: default:
@ -424,26 +594,33 @@ namespace bolt {
} }
} }
void Checker::check(SourceFile *SF) { TVSub Checker::check(SourceFile *SF) {
InferContext Toplevel; Contexts.push_back(new InferContext {});
Toplevel.addBinding("String", Forall(StringType)); ConstraintSet NoConstraints;
Toplevel.addBinding("Int", Forall(IntType)); addBinding("String", Forall(StringType));
Toplevel.addBinding("Bool", Forall(BoolType)); addBinding("Int", Forall(IntType));
Toplevel.addBinding("True", Forall(BoolType)); addBinding("Bool", Forall(BoolType));
Toplevel.addBinding("False", Forall(BoolType)); addBinding("True", Forall(BoolType));
Toplevel.addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType))); addBinding("False", Forall(BoolType));
Toplevel.addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType))); auto A = createTypeVar();
Toplevel.addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType))); TVSet SingleA { A };
Toplevel.addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType))); addBinding("==", Forall(SingleA, NoConstraints, new TArrow({ A, A }, BoolType)));
infer(SF, Toplevel); addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType)));
solve(new CMany(Toplevel.Constraints)); addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType)));
addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType)));
addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType)));
forwardDeclare(SF);
infer(SF);
TVSub Solution;
solve(new CMany(Contexts.front()->Constraints), Solution);
Contexts.pop_back();
return Solution;
} }
void Checker::solve(Constraint* Constraint) { void Checker::solve(Constraint* Constraint, TVSub& Solution) {
std::stack<class Constraint*> Queue; std::stack<class Constraint*> Queue;
Queue.push(Constraint); Queue.push(Constraint);
TVSub Solution;
while (!Queue.empty()) { while (!Queue.empty()) {
@ -459,7 +636,7 @@ namespace bolt {
case ConstraintKind::Many: case ConstraintKind::Many:
{ {
auto Y = static_cast<CMany*>(Constraint); auto Y = static_cast<CMany*>(Constraint);
for (auto Constraint: Y->Constraints) { for (auto Constraint: Y->Elements) {
Queue.push(Constraint); Queue.push(Constraint);
} }
break; break;
@ -468,7 +645,7 @@ namespace bolt {
case ConstraintKind::Equal: case ConstraintKind::Equal:
{ {
auto Y = static_cast<CEqual*>(Constraint); auto Y = static_cast<CEqual*>(Constraint);
std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl; //std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl;
if (!unify(Y->Left, Y->Right, Solution)) { if (!unify(Y->Left, Y->Right, Solution)) {
DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source); DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source);
} }
@ -530,6 +707,17 @@ namespace bolt {
return unify(Y->ReturnType, Z->ReturnType, Solution); return unify(Y->ReturnType, Z->ReturnType, Solution);
} }
if (A->getKind() == TypeKind::Arrow) {
auto Y = static_cast<TArrow*>(A);
if (Y->ParamTypes.empty()) {
return unify(Y->ReturnType, B, Solution);
}
}
if (B->getKind() == TypeKind::Arrow) {
return unify(B, A, Solution);
}
if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) { if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) {
auto Y = static_cast<TTuple*>(A); auto Y = static_cast<TTuple*>(A);
auto Z = static_cast<TTuple*>(B); auto Z = static_cast<TTuple*>(B);
@ -565,5 +753,22 @@ namespace bolt {
return false; return false;
} }
InferContext* Checker::lookupCall(Node* Source, SymbolPath Path) {
auto Def = Source->getScope()->lookup(Path);
auto Match = CallGraph.find(Def);
if (Match == CallGraph.end()) {
return nullptr;
}
return Match->second;
}
Type* Checker::getType(Node *Node, const TVSub &Solution) {
auto Match = Mapping.find(Node);
if (Match == Mapping.end()) {
return nullptr;
}
return Match->second->substitute(Solution);
}
} }

View file

@ -148,6 +148,13 @@ namespace bolt {
auto Name = parseQualifiedName(); auto Name = parseQualifiedName();
return new ReferenceExpression(Name); return new ReferenceExpression(Name);
} }
case NodeType::LParen:
{
Tokens.get();
auto E = parseExpression();
auto T2 = static_cast<RParen*>(expectToken(NodeType::RParen));
return new NestedExpression(static_cast<LParen*>(T0), E, T2);
}
case NodeType::IntegerLiteral: case NodeType::IntegerLiteral:
case NodeType::StringLiteral: case NodeType::StringLiteral:
Tokens.get(); Tokens.get();
@ -162,7 +169,7 @@ namespace bolt {
std::vector<Expression*> Args; std::vector<Expression*> Args;
for (;;) { for (;;) {
auto T1 = Tokens.peek(); auto T1 = Tokens.peek();
if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) { if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::RParen || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) {
break; break;
} }
Args.push_back(parsePrimitiveExpression()); Args.push_back(parsePrimitiveExpression());