Add partial support for recursive functions
This commit is contained in:
parent
43301a3a44
commit
cfb596f8e1
6 changed files with 499 additions and 132 deletions
|
@ -61,6 +61,7 @@ if (BOLT_ENABLE_TESTS)
|
||||||
add_executable(
|
add_executable(
|
||||||
alltests
|
alltests
|
||||||
src/TestText.cc
|
src/TestText.cc
|
||||||
|
src/TestChecker.cc
|
||||||
)
|
)
|
||||||
target_link_libraries(
|
target_link_libraries(
|
||||||
alltests
|
alltests
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
#ifndef BOLT_CST_HPP
|
#ifndef BOLT_CST_HPP
|
||||||
#define BOLT_CST_HPP
|
#define BOLT_CST_HPP
|
||||||
|
|
||||||
|
#include <istream>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "bolt/Text.hpp"
|
#include "bolt/Text.hpp"
|
||||||
|
@ -10,6 +12,10 @@
|
||||||
|
|
||||||
namespace bolt {
|
namespace bolt {
|
||||||
|
|
||||||
|
class Token;
|
||||||
|
class SourceFile;
|
||||||
|
class Scope;
|
||||||
|
|
||||||
enum class NodeType {
|
enum class NodeType {
|
||||||
Equals,
|
Equals,
|
||||||
Colon,
|
Colon,
|
||||||
|
@ -47,6 +53,7 @@ namespace bolt {
|
||||||
ArrowTypeExpression,
|
ArrowTypeExpression,
|
||||||
BindPattern,
|
BindPattern,
|
||||||
ReferenceExpression,
|
ReferenceExpression,
|
||||||
|
NestedExpression,
|
||||||
ConstantExpression,
|
ConstantExpression,
|
||||||
CallExpression,
|
CallExpression,
|
||||||
InfixExpression,
|
InfixExpression,
|
||||||
|
@ -65,8 +72,10 @@ namespace bolt {
|
||||||
SourceFile,
|
SourceFile,
|
||||||
};
|
};
|
||||||
|
|
||||||
class Token;
|
struct SymbolPath {
|
||||||
class SourceFile;
|
std::vector<ByteString> Modules;
|
||||||
|
ByteString Name;
|
||||||
|
};
|
||||||
|
|
||||||
class Node {
|
class Node {
|
||||||
|
|
||||||
|
@ -101,10 +110,28 @@ namespace bolt {
|
||||||
|
|
||||||
SourceFile* getSourceFile();
|
SourceFile* getSourceFile();
|
||||||
|
|
||||||
|
virtual Scope* getScope();
|
||||||
|
|
||||||
virtual ~Node();
|
virtual ~Node();
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Scope {
|
||||||
|
|
||||||
|
Node* Source;
|
||||||
|
std::unordered_map<ByteString, Node*> Mapping;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
inline Scope(Node* Source):
|
||||||
|
Source(Source) {}
|
||||||
|
|
||||||
|
Node* lookup(SymbolPath Path);
|
||||||
|
|
||||||
|
Scope* getParentScope();
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
class Token : public Node {
|
class Token : public Node {
|
||||||
|
|
||||||
TextLoc StartLoc;
|
TextLoc StartLoc;
|
||||||
|
@ -551,6 +578,8 @@ namespace bolt {
|
||||||
|
|
||||||
void setParents() override;
|
void setParents() override;
|
||||||
|
|
||||||
|
SymbolPath getSymbolPath() const;
|
||||||
|
|
||||||
~QualifiedName();
|
~QualifiedName();
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -661,6 +690,31 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class NestedExpression : public Expression {
|
||||||
|
public:
|
||||||
|
|
||||||
|
LParen* LParen;
|
||||||
|
Expression* Inner;
|
||||||
|
RParen* RParen;
|
||||||
|
|
||||||
|
inline NestedExpression(
|
||||||
|
class LParen* LParen,
|
||||||
|
Expression* Inner,
|
||||||
|
class RParen* RParen
|
||||||
|
): Expression(NodeType::NestedExpression),
|
||||||
|
LParen(LParen),
|
||||||
|
Inner(Inner),
|
||||||
|
RParen(RParen) {}
|
||||||
|
|
||||||
|
void setParents() override;
|
||||||
|
|
||||||
|
Token* getFirstToken() override;
|
||||||
|
Token* getLastToken() override;
|
||||||
|
|
||||||
|
~NestedExpression();
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
class ConstantExpression : public Expression {
|
class ConstantExpression : public Expression {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -931,9 +985,18 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Type;
|
||||||
|
class InferContext;
|
||||||
|
|
||||||
class LetDeclaration : public Node {
|
class LetDeclaration : public Node {
|
||||||
|
|
||||||
|
Scope TheScope;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
InferContext* Ctx;
|
||||||
|
class Type* Ty;
|
||||||
|
|
||||||
PubKeyword* PubKeyword;
|
PubKeyword* PubKeyword;
|
||||||
LetKeyword* LetKeyword;
|
LetKeyword* LetKeyword;
|
||||||
MutKeyword* MutKeyword;
|
MutKeyword* MutKeyword;
|
||||||
|
@ -951,6 +1014,7 @@ namespace bolt {
|
||||||
class TypeAssert* TypeAssert,
|
class TypeAssert* TypeAssert,
|
||||||
LetBody* Body
|
LetBody* Body
|
||||||
): Node(NodeType::LetDeclaration),
|
): Node(NodeType::LetDeclaration),
|
||||||
|
TheScope(this),
|
||||||
PubKeyword(PubKeyword),
|
PubKeyword(PubKeyword),
|
||||||
LetKeyword(LetKeywod),
|
LetKeyword(LetKeywod),
|
||||||
MutKeyword(MutKeyword),
|
MutKeyword(MutKeyword),
|
||||||
|
@ -959,6 +1023,10 @@ namespace bolt {
|
||||||
TypeAssert(TypeAssert),
|
TypeAssert(TypeAssert),
|
||||||
Body(Body) {}
|
Body(Body) {}
|
||||||
|
|
||||||
|
inline Scope* getScope() override {
|
||||||
|
return &TheScope;
|
||||||
|
}
|
||||||
|
|
||||||
void setParents() override;
|
void setParents() override;
|
||||||
|
|
||||||
Token* getFirstToken() override;
|
Token* getFirstToken() override;
|
||||||
|
@ -1025,6 +1093,9 @@ namespace bolt {
|
||||||
};
|
};
|
||||||
|
|
||||||
class SourceFile : public Node {
|
class SourceFile : public Node {
|
||||||
|
|
||||||
|
Scope TheScope;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
TextFile& File;
|
TextFile& File;
|
||||||
|
@ -1032,7 +1103,7 @@ namespace bolt {
|
||||||
std::vector<Node*> Elements;
|
std::vector<Node*> Elements;
|
||||||
|
|
||||||
SourceFile(TextFile& File, std::vector<Node*> Elements):
|
SourceFile(TextFile& File, std::vector<Node*> Elements):
|
||||||
Node(NodeType::SourceFile), File(File), Elements(Elements) {}
|
Node(NodeType::SourceFile), TheScope(this), File(File), Elements(Elements) {}
|
||||||
|
|
||||||
inline TextFile& getTextFile() {
|
inline TextFile& getTextFile() {
|
||||||
return File;
|
return File;
|
||||||
|
@ -1043,6 +1114,10 @@ namespace bolt {
|
||||||
Token* getFirstToken() override;
|
Token* getFirstToken() override;
|
||||||
Token* getLastToken() override;
|
Token* getLastToken() override;
|
||||||
|
|
||||||
|
inline Scope* getScope() override {
|
||||||
|
return &TheScope;
|
||||||
|
}
|
||||||
|
|
||||||
~SourceFile();
|
~SourceFile();
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -4,8 +4,9 @@
|
||||||
#include "zen/config.hpp"
|
#include "zen/config.hpp"
|
||||||
|
|
||||||
#include "bolt/ByteString.hpp"
|
#include "bolt/ByteString.hpp"
|
||||||
|
#include "bolt/CST.hpp"
|
||||||
|
|
||||||
#include <stack>
|
#include <istream>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -15,10 +16,6 @@ namespace bolt {
|
||||||
|
|
||||||
class DiagnosticEngine;
|
class DiagnosticEngine;
|
||||||
class Node;
|
class Node;
|
||||||
class Expression;
|
|
||||||
class TypeExpression;
|
|
||||||
class Pattern;
|
|
||||||
class SourceFile;
|
|
||||||
|
|
||||||
class Type;
|
class Type;
|
||||||
class TVar;
|
class TVar;
|
||||||
|
@ -47,6 +44,14 @@ namespace bolt {
|
||||||
|
|
||||||
bool hasTypeVar(const TVar* TV);
|
bool hasTypeVar(const TVar* TV);
|
||||||
|
|
||||||
|
void addTypeVars(TVSet& TVs);
|
||||||
|
|
||||||
|
inline TVSet getTypeVars() {
|
||||||
|
TVSet Out;
|
||||||
|
addTypeVars(Out);
|
||||||
|
return Out;
|
||||||
|
}
|
||||||
|
|
||||||
Type* substitute(const TVSub& Sub);
|
Type* substitute(const TVSub& Sub);
|
||||||
|
|
||||||
inline TypeKind getKind() const noexcept {
|
inline TypeKind getKind() const noexcept {
|
||||||
|
@ -273,14 +278,6 @@ namespace bolt {
|
||||||
inline InferContext(InferContext* Parent = nullptr):
|
inline InferContext(InferContext* Parent = nullptr):
|
||||||
Parent(Parent), ReturnType(nullptr) {}
|
Parent(Parent), ReturnType(nullptr) {}
|
||||||
|
|
||||||
void addConstraint(Constraint* C);
|
|
||||||
|
|
||||||
void addBinding(ByteString Name, Scheme Scm);
|
|
||||||
|
|
||||||
Type* lookupMono(ByteString Name);
|
|
||||||
|
|
||||||
Scheme* lookup(ByteString Name);
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class Checker {
|
class Checker {
|
||||||
|
@ -290,36 +287,66 @@ namespace bolt {
|
||||||
size_t nextConTypeId = 0;
|
size_t nextConTypeId = 0;
|
||||||
size_t nextTypeVarId = 0;
|
size_t nextTypeVarId = 0;
|
||||||
|
|
||||||
|
std::unordered_map<Node*, Type*> Mapping;
|
||||||
|
|
||||||
|
std::unordered_map<Node*, InferContext*> CallGraph;
|
||||||
|
|
||||||
Type* BoolType;
|
Type* BoolType;
|
||||||
Type* IntType;
|
Type* IntType;
|
||||||
Type* StringType;
|
Type* StringType;
|
||||||
|
|
||||||
std::stack<InferContext> Contexts;
|
std::vector<InferContext*> Contexts;
|
||||||
|
|
||||||
void addConstraint(Constraint* Constraint);
|
void addConstraint(Constraint* Constraint);
|
||||||
|
|
||||||
Type* inferExpression(Expression* Expression, InferContext& Ctx);
|
void forwardDeclare(Node* Node);
|
||||||
Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx);
|
|
||||||
|
|
||||||
void inferBindings(Pattern* Pattern, Type* T, InferContext& Ctx, ConstraintSet& Constraints, TVSet& Tvs);
|
Type* inferExpression(Expression* Expression);
|
||||||
|
Type* inferTypeExpression(TypeExpression* TE);
|
||||||
|
|
||||||
void infer(Node* node, InferContext& Ctx);
|
void inferBindings(Pattern* Pattern, Type* T, ConstraintSet& Constraints, TVSet& Tvs);
|
||||||
|
|
||||||
|
void infer(Node* node);
|
||||||
|
|
||||||
TCon* createPrimConType();
|
TCon* createPrimConType();
|
||||||
|
|
||||||
TVar* createTypeVar(InferContext& Ctx);
|
TVar* createTypeVar();
|
||||||
|
|
||||||
Type* instantiate(Scheme& S, InferContext& Ctx, Node* Source);
|
void addBinding(ByteString Name, Scheme Scm);
|
||||||
|
|
||||||
|
Type* lookupMono(ByteString Name);
|
||||||
|
|
||||||
|
InferContext* lookupCall(Node* Source, SymbolPath Path);
|
||||||
|
|
||||||
|
Type* getReturnType();
|
||||||
|
|
||||||
|
Scheme* lookup(ByteString Name);
|
||||||
|
|
||||||
|
Type* instantiate(Scheme& S, Node* Source);
|
||||||
|
|
||||||
bool unify(Type* A, Type* B, TVSub& Solution);
|
bool unify(Type* A, Type* B, TVSub& Solution);
|
||||||
|
|
||||||
void solve(Constraint* Constraint);
|
void solve(Constraint* Constraint, TVSub& Solution);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
Checker(DiagnosticEngine& DE);
|
Checker(DiagnosticEngine& DE);
|
||||||
|
|
||||||
void check(SourceFile* SF);
|
TVSub check(SourceFile* SF);
|
||||||
|
|
||||||
|
inline Type* getBoolType() {
|
||||||
|
return BoolType;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Type* getStringType() {
|
||||||
|
return StringType;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Type* getIntType() {
|
||||||
|
return IntType;
|
||||||
|
}
|
||||||
|
|
||||||
|
Type* getType(Node* Node, const TVSub& Solution);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
52
src/CST.cc
52
src/CST.cc
|
@ -5,6 +5,25 @@
|
||||||
|
|
||||||
namespace bolt {
|
namespace bolt {
|
||||||
|
|
||||||
|
Node* Scope::lookup(SymbolPath Path) {
|
||||||
|
auto Curr = this;
|
||||||
|
do {
|
||||||
|
auto Match = Curr->Mapping.find(Path.Name);
|
||||||
|
if (Match != Curr->Mapping.end()) {
|
||||||
|
return Match->second;
|
||||||
|
}
|
||||||
|
Curr = Curr->getParentScope();
|
||||||
|
} while (Curr != nullptr);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Scope* Scope::getParentScope() {
|
||||||
|
if (Source->Parent == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return Source->Parent->getScope();
|
||||||
|
}
|
||||||
|
|
||||||
SourceFile* Node::getSourceFile() {
|
SourceFile* Node::getSourceFile() {
|
||||||
auto CurrNode = this;
|
auto CurrNode = this;
|
||||||
for (;;) {
|
for (;;) {
|
||||||
|
@ -23,6 +42,10 @@ namespace bolt {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Scope* Node::getScope() {
|
||||||
|
return this->Parent->getScope();
|
||||||
|
}
|
||||||
|
|
||||||
TextLoc Token::getEndLoc() {
|
TextLoc Token::getEndLoc() {
|
||||||
auto EndLoc = StartLoc;
|
auto EndLoc = StartLoc;
|
||||||
EndLoc.advance(getText());
|
EndLoc.advance(getText());
|
||||||
|
@ -61,6 +84,13 @@ namespace bolt {
|
||||||
Name->Parent = this;
|
Name->Parent = this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NestedExpression::setParents() {
|
||||||
|
LParen->Parent = this;
|
||||||
|
Inner->Parent = this;
|
||||||
|
Inner->setParents();
|
||||||
|
RParen->Parent = this;
|
||||||
|
}
|
||||||
|
|
||||||
void ConstantExpression::setParents() {
|
void ConstantExpression::setParents() {
|
||||||
Token->Parent = this;
|
Token->Parent = this;
|
||||||
}
|
}
|
||||||
|
@ -330,6 +360,12 @@ namespace bolt {
|
||||||
Name->unref();
|
Name->unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NestedExpression::~NestedExpression() {
|
||||||
|
LParen->unref();
|
||||||
|
Inner->unref();
|
||||||
|
RParen->unref();
|
||||||
|
}
|
||||||
|
|
||||||
ConstantExpression::~ConstantExpression() {
|
ConstantExpression::~ConstantExpression() {
|
||||||
Token->unref();
|
Token->unref();
|
||||||
}
|
}
|
||||||
|
@ -493,6 +529,14 @@ namespace bolt {
|
||||||
return Name->getLastToken();
|
return Name->getLastToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Token* NestedExpression::getFirstToken() {
|
||||||
|
return LParen;
|
||||||
|
}
|
||||||
|
|
||||||
|
Token* NestedExpression::getLastToken() {
|
||||||
|
return RParen;
|
||||||
|
}
|
||||||
|
|
||||||
Token* ConstantExpression::getFirstToken() {
|
Token* ConstantExpression::getFirstToken() {
|
||||||
return Token;
|
return Token;
|
||||||
}
|
}
|
||||||
|
@ -786,5 +830,13 @@ namespace bolt {
|
||||||
return "..";
|
return "..";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SymbolPath QualifiedName::getSymbolPath() const {
|
||||||
|
std::vector<ByteString> ModuleNames;
|
||||||
|
for (auto Ident: ModulePath) {
|
||||||
|
ModuleNames.push_back(Ident->Text);
|
||||||
|
}
|
||||||
|
return SymbolPath { ModuleNames, Name->Text };
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
417
src/Checker.cc
417
src/Checker.cc
|
@ -11,6 +11,41 @@ namespace bolt {
|
||||||
|
|
||||||
std::string describe(const Type* Ty);
|
std::string describe(const Type* Ty);
|
||||||
|
|
||||||
|
void Type::addTypeVars(TVSet& TVs) {
|
||||||
|
switch (Kind) {
|
||||||
|
case TypeKind::Var:
|
||||||
|
TVs.emplace(static_cast<TVar*>(this));
|
||||||
|
break;
|
||||||
|
case TypeKind::Arrow:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<TArrow*>(this);
|
||||||
|
for (auto Ty: Y->ParamTypes) {
|
||||||
|
Ty->addTypeVars(TVs);
|
||||||
|
}
|
||||||
|
Y->ReturnType->addTypeVars(TVs);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TypeKind::Con:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<TCon*>(this);
|
||||||
|
for (auto Ty: Y->Args) {
|
||||||
|
Ty->addTypeVars(TVs);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TypeKind::Tuple:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<TTuple*>(this);
|
||||||
|
for (auto Ty: Y->ElementTypes) {
|
||||||
|
Ty->addTypeVars(TVs);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case TypeKind::Any:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool Type::hasTypeVar(const TVar* TV) {
|
bool Type::hasTypeVar(const TVar* TV) {
|
||||||
switch (Kind) {
|
switch (Kind) {
|
||||||
case TypeKind::Var:
|
case TypeKind::Var:
|
||||||
|
@ -61,32 +96,50 @@ namespace bolt {
|
||||||
case TypeKind::Arrow:
|
case TypeKind::Arrow:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<TArrow*>(this);
|
auto Y = static_cast<TArrow*>(this);
|
||||||
|
bool Changed = false;
|
||||||
std::vector<Type*> NewParamTypes;
|
std::vector<Type*> NewParamTypes;
|
||||||
for (auto Ty: Y->ParamTypes) {
|
for (auto Ty: Y->ParamTypes) {
|
||||||
NewParamTypes.push_back(Ty->substitute(Sub));
|
auto NewParamType = Ty->substitute(Sub);
|
||||||
|
if (NewParamType != Ty) {
|
||||||
|
Changed = true;
|
||||||
|
}
|
||||||
|
NewParamTypes.push_back(NewParamType);
|
||||||
}
|
}
|
||||||
auto NewRetTy = Y->ReturnType->substitute(Sub) ;
|
auto NewRetTy = Y->ReturnType->substitute(Sub) ;
|
||||||
return new TArrow(NewParamTypes, NewRetTy);
|
if (NewRetTy != Y->ReturnType) {
|
||||||
|
Changed = true;
|
||||||
|
}
|
||||||
|
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
|
||||||
}
|
}
|
||||||
case TypeKind::Any:
|
case TypeKind::Any:
|
||||||
return this;
|
return this;
|
||||||
case TypeKind::Con:
|
case TypeKind::Con:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<TCon*>(this);
|
auto Y = static_cast<TCon*>(this);
|
||||||
|
bool Changed = false;
|
||||||
std::vector<Type*> NewArgs;
|
std::vector<Type*> NewArgs;
|
||||||
for (auto Arg: Y->Args) {
|
for (auto Arg: Y->Args) {
|
||||||
NewArgs.push_back(Arg->substitute(Sub));
|
auto NewArg = Arg->substitute(Sub);
|
||||||
|
if (NewArg != Arg) {
|
||||||
|
Changed = true;
|
||||||
|
}
|
||||||
|
NewArgs.push_back(NewArg);
|
||||||
}
|
}
|
||||||
return new TCon(Y->Id, NewArgs, Y->DisplayName);
|
return Changed ? new TCon(Y->Id, NewArgs, Y->DisplayName) : this;
|
||||||
}
|
}
|
||||||
case TypeKind::Tuple:
|
case TypeKind::Tuple:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<TTuple*>(this);
|
auto Y = static_cast<TTuple*>(this);
|
||||||
|
bool Changed = false;
|
||||||
std::vector<Type*> NewElementTypes;
|
std::vector<Type*> NewElementTypes;
|
||||||
for (auto Ty: Y->ElementTypes) {
|
for (auto Ty: Y->ElementTypes) {
|
||||||
NewElementTypes.push_back(Ty->substitute(Sub));
|
auto NewElementType = Ty->substitute(Sub);
|
||||||
|
if (NewElementType != Ty) {
|
||||||
|
Changed = true;
|
||||||
|
}
|
||||||
|
NewElementTypes.push_back(NewElementType);
|
||||||
}
|
}
|
||||||
return new TTuple(NewElementTypes);
|
return Changed ? new TTuple(NewElementTypes) : this;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,7 +155,7 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Y = static_cast<CMany*>(this);
|
auto Y = static_cast<CMany*>(this);
|
||||||
auto NewConstraints = new ConstraintSet();
|
auto NewConstraints = new ConstraintSet();
|
||||||
for (auto Element: Y->Constraints) {
|
for (auto Element: Y->Elements) {
|
||||||
NewConstraints->push_back(Element->substitute(Sub));
|
NewConstraints->push_back(Element->substitute(Sub));
|
||||||
}
|
}
|
||||||
return new CMany(*NewConstraints);
|
return new CMany(*NewConstraints);
|
||||||
|
@ -112,21 +165,25 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Scheme* InferContext::lookup(ByteString Name) {
|
Checker::Checker(DiagnosticEngine& DE):
|
||||||
InferContext* Curr = this;
|
DE(DE) {
|
||||||
for (;;) {
|
BoolType = new TCon(nextConTypeId++, {}, "Bool");
|
||||||
|
IntType = new TCon(nextConTypeId++, {}, "Int");
|
||||||
|
StringType = new TCon(nextConTypeId++, {}, "String");
|
||||||
|
}
|
||||||
|
|
||||||
|
Scheme* Checker::lookup(ByteString Name) {
|
||||||
|
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
|
||||||
|
auto Curr = *Iter;
|
||||||
auto Match = Curr->Env.find(Name);
|
auto Match = Curr->Env.find(Name);
|
||||||
if (Match != Curr->Env.end()) {
|
if (Match != Curr->Env.end()) {
|
||||||
return &Match->second;
|
return &Match->second;
|
||||||
}
|
}
|
||||||
Curr = Curr->Parent;
|
|
||||||
if (Curr == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* InferContext::lookupMono(ByteString Name) {
|
Type* Checker::lookupMono(ByteString Name) {
|
||||||
auto Scm = lookup(Name);
|
auto Scm = lookup(Name);
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -136,22 +193,126 @@ namespace bolt {
|
||||||
return F.Type;
|
return F.Type;
|
||||||
}
|
}
|
||||||
|
|
||||||
void InferContext::addBinding(ByteString Name, Scheme S) {
|
void Checker::addBinding(ByteString Name, Scheme S) {
|
||||||
Env.emplace(Name, S);
|
Contexts.back()->Env.emplace(Name, S);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InferContext::addConstraint(Constraint *C) {
|
Type* Checker::getReturnType() {
|
||||||
Constraints.push_back(C);
|
auto Ty = Contexts.back()->ReturnType;
|
||||||
|
ZEN_ASSERT(Ty != nullptr);
|
||||||
|
return Ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
Checker::Checker(DiagnosticEngine& DE):
|
static bool hasTypeVar(TVSet& Set, Type* Type) {
|
||||||
DE(DE) {
|
for (auto TV: Type->getTypeVars()) {
|
||||||
BoolType = new TCon(nextConTypeId++, {}, "Bool");
|
if (Set.count(TV)) {
|
||||||
IntType = new TCon(nextConTypeId++, {}, "Int");
|
return true;
|
||||||
StringType = new TCon(nextConTypeId++, {}, "String");
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Checker::addConstraint(Constraint* Constraint) {
|
||||||
|
switch (Constraint->getKind()) {
|
||||||
|
case ConstraintKind::Equal:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<CEqual*>(Constraint);
|
||||||
|
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
|
||||||
|
auto& Ctx = **Iter;
|
||||||
|
if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) {
|
||||||
|
Ctx.Constraints.push_back(Constraint);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Contexts.front()->Constraints.push_back(Constraint);
|
||||||
|
//auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth);
|
||||||
|
//ZEN_ASSERT(I < Contexts.size());
|
||||||
|
//auto Ctx = Contexts[I];
|
||||||
|
//Ctx->Constraints.push_back(Constraint);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ConstraintKind::Many:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<CMany*>(Constraint);
|
||||||
|
for (auto Element: Y->Elements) {
|
||||||
|
addConstraint(Element);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ConstraintKind::Empty:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Checker::forwardDeclare(Node* X) {
|
||||||
|
|
||||||
|
switch (X->Type) {
|
||||||
|
|
||||||
|
case NodeType::ExpressionStatement:
|
||||||
|
case NodeType::ReturnStatement:
|
||||||
|
case NodeType::IfStatement:
|
||||||
|
break;
|
||||||
|
|
||||||
|
case NodeType::SourceFile:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<SourceFile*>(X);
|
||||||
|
for (auto Element: Y->Elements) {
|
||||||
|
forwardDeclare(Element) ;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case NodeType::LetDeclaration:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<LetDeclaration*>(X);
|
||||||
|
|
||||||
|
auto NewCtx = new InferContext();
|
||||||
|
Y->Ctx = NewCtx;
|
||||||
|
std::cerr << Y << std::endl;
|
||||||
|
|
||||||
|
Contexts.push_back(NewCtx);
|
||||||
|
|
||||||
|
Type* Ty;
|
||||||
|
if (Y->TypeAssert) {
|
||||||
|
Ty = inferTypeExpression(Y->TypeAssert->TypeExpression);
|
||||||
|
} else {
|
||||||
|
Ty = createTypeVar();
|
||||||
|
}
|
||||||
|
Y->Ty = Ty;
|
||||||
|
|
||||||
|
if (Y->Body) {
|
||||||
|
switch (Y->Body->Type) {
|
||||||
|
case NodeType::LetExprBody:
|
||||||
|
break;
|
||||||
|
case NodeType::LetBlockBody:
|
||||||
|
{
|
||||||
|
auto Z = static_cast<LetBlockBody*>(Y->Body);
|
||||||
|
for (auto Element: Z->Elements) {
|
||||||
|
forwardDeclare(Element);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ZEN_UNREACHABLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Contexts.pop_back();
|
||||||
|
|
||||||
|
inferBindings(Y->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs);
|
||||||
|
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
ZEN_UNREACHABLE
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::infer(Node* X, InferContext& Ctx) {
|
}
|
||||||
|
|
||||||
|
void Checker::infer(Node* X) {
|
||||||
|
|
||||||
switch (X->Type) {
|
switch (X->Type) {
|
||||||
|
|
||||||
|
@ -159,7 +320,7 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Y = static_cast<SourceFile*>(X);
|
auto Y = static_cast<SourceFile*>(X);
|
||||||
for (auto Element: Y->Elements) {
|
for (auto Element: Y->Elements) {
|
||||||
infer(Element, Ctx);
|
infer(Element);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -169,10 +330,10 @@ namespace bolt {
|
||||||
auto Y = static_cast<IfStatement*>(X);
|
auto Y = static_cast<IfStatement*>(X);
|
||||||
for (auto Part: Y->Parts) {
|
for (auto Part: Y->Parts) {
|
||||||
if (Part->Test != nullptr) {
|
if (Part->Test != nullptr) {
|
||||||
Ctx.addConstraint(new CEqual { BoolType, inferExpression(Part->Test, Ctx), Part->Test });
|
addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test });
|
||||||
}
|
}
|
||||||
for (auto Element: Part->Elements) {
|
for (auto Element: Part->Elements) {
|
||||||
infer(Element, Ctx);
|
infer(Element);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -182,24 +343,18 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Y = static_cast<LetDeclaration*>(X);
|
auto Y = static_cast<LetDeclaration*>(X);
|
||||||
|
|
||||||
auto NewCtx = new InferContext { Ctx };
|
auto NewCtx = Y->Ctx;
|
||||||
|
Contexts.push_back(NewCtx);
|
||||||
Type* Ty;
|
|
||||||
if (Y->TypeAssert) {
|
|
||||||
Ty = inferTypeExpression(Y->TypeAssert->TypeExpression, *NewCtx);
|
|
||||||
} else {
|
|
||||||
Ty = createTypeVar(*NewCtx);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Type*> ParamTypes;
|
std::vector<Type*> ParamTypes;
|
||||||
Type* RetType;
|
Type* RetType;
|
||||||
|
|
||||||
for (auto Param: Y->Params) {
|
for (auto Param: Y->Params) {
|
||||||
// TODO incorporate Param->TypeAssert or make it a kind of pattern
|
// TODO incorporate Param->TypeAssert or make it a kind of pattern
|
||||||
TVar* TV = createTypeVar(*NewCtx);
|
TVar* TV = createTypeVar();
|
||||||
TVSet NoTVs;
|
TVSet NoTVs;
|
||||||
ConstraintSet NoConstraints;
|
ConstraintSet NoConstraints;
|
||||||
inferBindings(Param->Pattern, TV, *NewCtx, NoConstraints, NoTVs);
|
inferBindings(Param->Pattern, TV, NoConstraints, NoTVs);
|
||||||
ParamTypes.push_back(TV);
|
ParamTypes.push_back(TV);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,16 +363,16 @@ namespace bolt {
|
||||||
case NodeType::LetExprBody:
|
case NodeType::LetExprBody:
|
||||||
{
|
{
|
||||||
auto Z = static_cast<LetExprBody*>(Y->Body);
|
auto Z = static_cast<LetExprBody*>(Y->Body);
|
||||||
RetType = inferExpression(Z->Expression, *NewCtx);
|
RetType = inferExpression(Z->Expression);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case NodeType::LetBlockBody:
|
case NodeType::LetBlockBody:
|
||||||
{
|
{
|
||||||
auto Z = static_cast<LetBlockBody*>(Y->Body);
|
auto Z = static_cast<LetBlockBody*>(Y->Body);
|
||||||
RetType = createTypeVar(*NewCtx);
|
RetType = createTypeVar();
|
||||||
NewCtx->ReturnType = RetType;
|
NewCtx->ReturnType = RetType;
|
||||||
for (auto Element: Z->Elements) {
|
for (auto Element: Z->Elements) {
|
||||||
infer(Element, *NewCtx);
|
infer(Element);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -225,12 +380,12 @@ namespace bolt {
|
||||||
ZEN_UNREACHABLE
|
ZEN_UNREACHABLE
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
RetType = createTypeVar(*NewCtx);
|
RetType = createTypeVar();
|
||||||
}
|
}
|
||||||
|
|
||||||
NewCtx->addConstraint(new CEqual { Ty, new TArrow(ParamTypes, RetType), X });
|
addConstraint(new CEqual { Y->Ty, new TArrow(ParamTypes, RetType), X });
|
||||||
|
|
||||||
inferBindings(Y->Pattern, Ty, Ctx, NewCtx->Constraints, NewCtx->TVs);
|
Contexts.pop_back();
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -240,19 +395,18 @@ namespace bolt {
|
||||||
auto Y = static_cast<ReturnStatement*>(X);
|
auto Y = static_cast<ReturnStatement*>(X);
|
||||||
Type* ReturnType;
|
Type* ReturnType;
|
||||||
if (Y->Expression) {
|
if (Y->Expression) {
|
||||||
ReturnType = inferExpression(Y->Expression, Ctx);
|
ReturnType = inferExpression(Y->Expression);
|
||||||
} else {
|
} else {
|
||||||
ReturnType = new TTuple({});
|
ReturnType = new TTuple({});
|
||||||
}
|
}
|
||||||
ZEN_ASSERT(Ctx.ReturnType != nullptr);
|
addConstraint(new CEqual { ReturnType, getReturnType(), X });
|
||||||
Ctx.addConstraint(new CEqual { ReturnType, Ctx.ReturnType, X });
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case NodeType::ExpressionStatement:
|
case NodeType::ExpressionStatement:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<ExpressionStatement*>(X);
|
auto Y = static_cast<ExpressionStatement*>(X);
|
||||||
inferExpression(Y->Expression, Ctx);
|
inferExpression(Y->Expression);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,13 +417,13 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TVar* Checker::createTypeVar(InferContext& Ctx) {
|
TVar* Checker::createTypeVar() {
|
||||||
auto TV = new TVar(nextTypeVarId++);
|
auto TV = new TVar(nextTypeVarId++);
|
||||||
Ctx.TVs.emplace(TV);
|
Contexts.back()->TVs.emplace(TV);
|
||||||
return TV;
|
return TV;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::instantiate(Scheme& S, InferContext& Ctx, Node* Source) {
|
Type* Checker::instantiate(Scheme& S, Node* Source) {
|
||||||
|
|
||||||
switch (S.getKind()) {
|
switch (S.getKind()) {
|
||||||
|
|
||||||
|
@ -278,47 +432,46 @@ namespace bolt {
|
||||||
auto& F = S.as<Forall>();
|
auto& F = S.as<Forall>();
|
||||||
|
|
||||||
TVSub Sub;
|
TVSub Sub;
|
||||||
if (F.TVs) {
|
for (auto TV: *F.TVs) {
|
||||||
for (auto TV: *F.TVs) {
|
Sub[TV] = createTypeVar();
|
||||||
Sub[TV] = createTypeVar(Ctx);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (F.Constraints) {
|
for (auto Constraint: *F.Constraints) {
|
||||||
|
|
||||||
for (auto Constraint: *F.Constraints) {
|
auto NewConstraint = Constraint->substitute(Sub);
|
||||||
|
|
||||||
auto NewConstraint = Constraint->substitute(Sub);
|
// This makes error messages prettier by relating the typing failure
|
||||||
|
// to the call site rather than the definition.
|
||||||
// This makes error messages prettier by relating the typing failure
|
if (NewConstraint->getKind() == ConstraintKind::Equal) {
|
||||||
// to the call site rather than the definition.
|
static_cast<CEqual *>(NewConstraint)->Source = Source;
|
||||||
if (NewConstraint->getKind() == ConstraintKind::Equal) {
|
|
||||||
static_cast<CEqual *>(NewConstraint)->Source = Source;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ctx.addConstraint(NewConstraint);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addConstraint(NewConstraint);
|
||||||
}
|
}
|
||||||
|
|
||||||
return F.Type->substitute(Sub);
|
// FIXME substitute should always clone if we set MaxDepth
|
||||||
|
auto NewType = F.Type->substitute(Sub);
|
||||||
|
//NewType->MaxDepth = std::max(static_cast<unsigned>(Contexts.size()-1), F.Type->MaxDepth);
|
||||||
|
return NewType;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::inferTypeExpression(TypeExpression* X, InferContext& Ctx) {
|
Type* Checker::inferTypeExpression(TypeExpression* X) {
|
||||||
|
|
||||||
switch (X->Type) {
|
switch (X->Type) {
|
||||||
|
|
||||||
case NodeType::ReferenceTypeExpression:
|
case NodeType::ReferenceTypeExpression:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<ReferenceTypeExpression*>(X);
|
auto Y = static_cast<ReferenceTypeExpression*>(X);
|
||||||
auto Ty = Ctx.lookupMono(Y->Name->Name->Text);
|
auto Ty = lookupMono(Y->Name->Name->Text);
|
||||||
if (Ty == nullptr) {
|
if (Ty == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name->Name);
|
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name->Name);
|
||||||
return new TAny();
|
return new TAny();
|
||||||
}
|
}
|
||||||
|
Mapping[X] = Ty;
|
||||||
return Ty;
|
return Ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -327,10 +480,12 @@ namespace bolt {
|
||||||
auto Y = static_cast<ArrowTypeExpression*>(X);
|
auto Y = static_cast<ArrowTypeExpression*>(X);
|
||||||
std::vector<Type*> ParamTypes;
|
std::vector<Type*> ParamTypes;
|
||||||
for (auto ParamType: Y->ParamTypes) {
|
for (auto ParamType: Y->ParamTypes) {
|
||||||
ParamTypes.push_back(inferTypeExpression(ParamType, Ctx));
|
ParamTypes.push_back(inferTypeExpression(ParamType));
|
||||||
}
|
}
|
||||||
auto ReturnType = inferTypeExpression(Y->ReturnType, Ctx);
|
auto ReturnType = inferTypeExpression(Y->ReturnType);
|
||||||
return new TArrow(ParamTypes, ReturnType);
|
auto Ty = new TArrow(ParamTypes, ReturnType);
|
||||||
|
Mapping[X] = Ty;
|
||||||
|
return Ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -339,7 +494,7 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::inferExpression(Expression* X, InferContext& Ctx) {
|
Type* Checker::inferExpression(Expression* X) {
|
||||||
|
|
||||||
switch (X->Type) {
|
switch (X->Type) {
|
||||||
|
|
||||||
|
@ -349,15 +504,16 @@ namespace bolt {
|
||||||
Type* Ty = nullptr;
|
Type* Ty = nullptr;
|
||||||
switch (Y->Token->Type) {
|
switch (Y->Token->Type) {
|
||||||
case NodeType::IntegerLiteral:
|
case NodeType::IntegerLiteral:
|
||||||
Ty = Ctx.lookupMono("Int");
|
Ty = lookupMono("Int");
|
||||||
break;
|
break;
|
||||||
case NodeType::StringLiteral:
|
case NodeType::StringLiteral:
|
||||||
Ty = Ctx.lookupMono("String");
|
Ty = lookupMono("String");
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
ZEN_UNREACHABLE
|
ZEN_UNREACHABLE
|
||||||
}
|
}
|
||||||
ZEN_ASSERT(Ty != nullptr);
|
ZEN_ASSERT(Ty != nullptr);
|
||||||
|
Mapping[X] = Ty;
|
||||||
return Ty;
|
return Ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,44 +521,58 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Y = static_cast<ReferenceExpression*>(X);
|
auto Y = static_cast<ReferenceExpression*>(X);
|
||||||
ZEN_ASSERT(Y->Name->ModulePath.empty());
|
ZEN_ASSERT(Y->Name->ModulePath.empty());
|
||||||
auto Scm = Ctx.lookup(Y->Name->Name->Text);
|
auto Ctx = lookupCall(Y, Y->Name->getSymbolPath());
|
||||||
|
if (Ctx) {
|
||||||
|
return Ctx->ReturnType;
|
||||||
|
}
|
||||||
|
auto Scm = lookup(Y->Name->Name->Text);
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
|
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
|
||||||
return new TAny();
|
return new TAny();
|
||||||
}
|
}
|
||||||
return instantiate(*Scm, Ctx, X);
|
auto Ty = instantiate(*Scm, X);
|
||||||
|
Mapping[X] = Ty;
|
||||||
|
return Ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
case NodeType::CallExpression:
|
case NodeType::CallExpression:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<CallExpression*>(X);
|
auto Y = static_cast<CallExpression*>(X);
|
||||||
auto OpTy = inferExpression(Y->Function, Ctx);
|
auto OpTy = inferExpression(Y->Function);
|
||||||
auto RetType = createTypeVar(Ctx);
|
auto RetType = createTypeVar();
|
||||||
std::vector<Type*> ArgTypes;
|
std::vector<Type*> ArgTypes;
|
||||||
for (auto Arg: Y->Args) {
|
for (auto Arg: Y->Args) {
|
||||||
ArgTypes.push_back(inferExpression(Arg, Ctx));
|
ArgTypes.push_back(inferExpression(Arg));
|
||||||
}
|
}
|
||||||
Ctx.addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X });
|
addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X });
|
||||||
|
Mapping[X] = RetType;
|
||||||
return RetType;
|
return RetType;
|
||||||
}
|
}
|
||||||
|
|
||||||
case NodeType::InfixExpression:
|
case NodeType::InfixExpression:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<InfixExpression*>(X);
|
auto Y = static_cast<InfixExpression*>(X);
|
||||||
auto Scm = Ctx.lookup(Y->Operator->getText());
|
auto Scm = lookup(Y->Operator->getText());
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
|
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
|
||||||
return new TAny();
|
return new TAny();
|
||||||
}
|
}
|
||||||
auto OpTy = instantiate(*Scm, Ctx, Y->Operator);
|
auto OpTy = instantiate(*Scm, Y->Operator);
|
||||||
auto RetTy = createTypeVar(Ctx);
|
auto RetTy = createTypeVar();
|
||||||
std::vector<Type*> ArgTys;
|
std::vector<Type*> ArgTys;
|
||||||
ArgTys.push_back(inferExpression(Y->LHS, Ctx));
|
ArgTys.push_back(inferExpression(Y->LHS));
|
||||||
ArgTys.push_back(inferExpression(Y->RHS, Ctx));
|
ArgTys.push_back(inferExpression(Y->RHS));
|
||||||
Ctx.addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X });
|
addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X });
|
||||||
|
Mapping[X] = RetTy;
|
||||||
return RetTy;
|
return RetTy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case NodeType::NestedExpression:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<NestedExpression*>(X);
|
||||||
|
return inferExpression(Y->Inner);
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
ZEN_UNREACHABLE
|
ZEN_UNREACHABLE
|
||||||
|
|
||||||
|
@ -410,12 +580,12 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::inferBindings(Pattern* Pattern, Type* Type, InferContext& Ctx, ConstraintSet& Constraints, TVSet& TVs) {
|
void Checker::inferBindings(Pattern* Pattern, Type* Type, ConstraintSet& Constraints, TVSet& TVs) {
|
||||||
|
|
||||||
switch (Pattern->Type) {
|
switch (Pattern->Type) {
|
||||||
|
|
||||||
case NodeType::BindPattern:
|
case NodeType::BindPattern:
|
||||||
Ctx.addBinding(static_cast<BindPattern*>(Pattern)->Name->Text, Forall(TVs, Constraints, Type));
|
addBinding(static_cast<BindPattern*>(Pattern)->Name->Text, Forall(TVs, Constraints, Type));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -424,26 +594,33 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::check(SourceFile *SF) {
|
TVSub Checker::check(SourceFile *SF) {
|
||||||
InferContext Toplevel;
|
Contexts.push_back(new InferContext {});
|
||||||
Toplevel.addBinding("String", Forall(StringType));
|
ConstraintSet NoConstraints;
|
||||||
Toplevel.addBinding("Int", Forall(IntType));
|
addBinding("String", Forall(StringType));
|
||||||
Toplevel.addBinding("Bool", Forall(BoolType));
|
addBinding("Int", Forall(IntType));
|
||||||
Toplevel.addBinding("True", Forall(BoolType));
|
addBinding("Bool", Forall(BoolType));
|
||||||
Toplevel.addBinding("False", Forall(BoolType));
|
addBinding("True", Forall(BoolType));
|
||||||
Toplevel.addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType)));
|
addBinding("False", Forall(BoolType));
|
||||||
Toplevel.addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType)));
|
auto A = createTypeVar();
|
||||||
Toplevel.addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType)));
|
TVSet SingleA { A };
|
||||||
Toplevel.addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType)));
|
addBinding("==", Forall(SingleA, NoConstraints, new TArrow({ A, A }, BoolType)));
|
||||||
infer(SF, Toplevel);
|
addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||||
solve(new CMany(Toplevel.Constraints));
|
addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||||
|
addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||||
|
addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||||
|
forwardDeclare(SF);
|
||||||
|
infer(SF);
|
||||||
|
TVSub Solution;
|
||||||
|
solve(new CMany(Contexts.front()->Constraints), Solution);
|
||||||
|
Contexts.pop_back();
|
||||||
|
return Solution;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::solve(Constraint* Constraint) {
|
void Checker::solve(Constraint* Constraint, TVSub& Solution) {
|
||||||
|
|
||||||
std::stack<class Constraint*> Queue;
|
std::stack<class Constraint*> Queue;
|
||||||
Queue.push(Constraint);
|
Queue.push(Constraint);
|
||||||
TVSub Solution;
|
|
||||||
|
|
||||||
while (!Queue.empty()) {
|
while (!Queue.empty()) {
|
||||||
|
|
||||||
|
@ -459,7 +636,7 @@ namespace bolt {
|
||||||
case ConstraintKind::Many:
|
case ConstraintKind::Many:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<CMany*>(Constraint);
|
auto Y = static_cast<CMany*>(Constraint);
|
||||||
for (auto Constraint: Y->Constraints) {
|
for (auto Constraint: Y->Elements) {
|
||||||
Queue.push(Constraint);
|
Queue.push(Constraint);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -468,7 +645,7 @@ namespace bolt {
|
||||||
case ConstraintKind::Equal:
|
case ConstraintKind::Equal:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<CEqual*>(Constraint);
|
auto Y = static_cast<CEqual*>(Constraint);
|
||||||
std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl;
|
//std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl;
|
||||||
if (!unify(Y->Left, Y->Right, Solution)) {
|
if (!unify(Y->Left, Y->Right, Solution)) {
|
||||||
DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source);
|
DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source);
|
||||||
}
|
}
|
||||||
|
@ -530,6 +707,17 @@ namespace bolt {
|
||||||
return unify(Y->ReturnType, Z->ReturnType, Solution);
|
return unify(Y->ReturnType, Z->ReturnType, Solution);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (A->getKind() == TypeKind::Arrow) {
|
||||||
|
auto Y = static_cast<TArrow*>(A);
|
||||||
|
if (Y->ParamTypes.empty()) {
|
||||||
|
return unify(Y->ReturnType, B, Solution);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (B->getKind() == TypeKind::Arrow) {
|
||||||
|
return unify(B, A, Solution);
|
||||||
|
}
|
||||||
|
|
||||||
if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) {
|
if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) {
|
||||||
auto Y = static_cast<TTuple*>(A);
|
auto Y = static_cast<TTuple*>(A);
|
||||||
auto Z = static_cast<TTuple*>(B);
|
auto Z = static_cast<TTuple*>(B);
|
||||||
|
@ -565,5 +753,22 @@ namespace bolt {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
InferContext* Checker::lookupCall(Node* Source, SymbolPath Path) {
|
||||||
|
auto Def = Source->getScope()->lookup(Path);
|
||||||
|
auto Match = CallGraph.find(Def);
|
||||||
|
if (Match == CallGraph.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return Match->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
Type* Checker::getType(Node *Node, const TVSub &Solution) {
|
||||||
|
auto Match = Mapping.find(Node);
|
||||||
|
if (Match == Mapping.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return Match->second->substitute(Solution);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -148,6 +148,13 @@ namespace bolt {
|
||||||
auto Name = parseQualifiedName();
|
auto Name = parseQualifiedName();
|
||||||
return new ReferenceExpression(Name);
|
return new ReferenceExpression(Name);
|
||||||
}
|
}
|
||||||
|
case NodeType::LParen:
|
||||||
|
{
|
||||||
|
Tokens.get();
|
||||||
|
auto E = parseExpression();
|
||||||
|
auto T2 = static_cast<RParen*>(expectToken(NodeType::RParen));
|
||||||
|
return new NestedExpression(static_cast<LParen*>(T0), E, T2);
|
||||||
|
}
|
||||||
case NodeType::IntegerLiteral:
|
case NodeType::IntegerLiteral:
|
||||||
case NodeType::StringLiteral:
|
case NodeType::StringLiteral:
|
||||||
Tokens.get();
|
Tokens.get();
|
||||||
|
@ -162,7 +169,7 @@ namespace bolt {
|
||||||
std::vector<Expression*> Args;
|
std::vector<Expression*> Args;
|
||||||
for (;;) {
|
for (;;) {
|
||||||
auto T1 = Tokens.peek();
|
auto T1 = Tokens.peek();
|
||||||
if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) {
|
if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::RParen || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Args.push_back(parsePrimitiveExpression());
|
Args.push_back(parsePrimitiveExpression());
|
||||||
|
|
Loading…
Reference in a new issue