Improve type inference and some minor updates
This commit is contained in:
parent
45b5f113a0
commit
b4d54f025c
8 changed files with 357 additions and 93 deletions
|
@ -21,6 +21,7 @@ namespace bolt {
|
||||||
RBracket,
|
RBracket,
|
||||||
LBrace,
|
LBrace,
|
||||||
RBrace,
|
RBrace,
|
||||||
|
RArrow,
|
||||||
LetKeyword,
|
LetKeyword,
|
||||||
MutKeyword,
|
MutKeyword,
|
||||||
PubKeyword,
|
PubKeyword,
|
||||||
|
@ -40,6 +41,7 @@ namespace bolt {
|
||||||
IntegerLiteral,
|
IntegerLiteral,
|
||||||
QualifiedName,
|
QualifiedName,
|
||||||
ReferenceTypeExpression,
|
ReferenceTypeExpression,
|
||||||
|
ArrowTypeExpression,
|
||||||
BindPattern,
|
BindPattern,
|
||||||
ReferenceExpression,
|
ReferenceExpression,
|
||||||
ConstantExpression,
|
ConstantExpression,
|
||||||
|
@ -168,6 +170,18 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class RArrow : public Token {
|
||||||
|
public:
|
||||||
|
|
||||||
|
RArrow(TextLoc StartLoc):
|
||||||
|
Token(NodeType::RArrow, StartLoc) {}
|
||||||
|
|
||||||
|
std::string getText() const override;
|
||||||
|
|
||||||
|
~RArrow();
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
class Dot : public Token {
|
class Dot : public Token {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -528,6 +542,28 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ArrowTypeExpression : public TypeExpression {
|
||||||
|
public:
|
||||||
|
|
||||||
|
std::vector<TypeExpression*> ParamTypes;
|
||||||
|
TypeExpression* ReturnType;
|
||||||
|
|
||||||
|
inline ArrowTypeExpression(
|
||||||
|
std::vector<TypeExpression*> ParamTypes,
|
||||||
|
TypeExpression* ReturnType
|
||||||
|
): TypeExpression(NodeType::ArrowTypeExpression),
|
||||||
|
ParamTypes(ParamTypes),
|
||||||
|
ReturnType(ReturnType) {}
|
||||||
|
|
||||||
|
void setParents() override;
|
||||||
|
|
||||||
|
Token* getFirstToken() override;
|
||||||
|
Token* getLastToken() override;
|
||||||
|
|
||||||
|
~ArrowTypeExpression();
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
class Pattern : public Node {
|
class Pattern : public Node {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
|
@ -105,22 +105,24 @@ namespace bolt {
|
||||||
|
|
||||||
class Constraint;
|
class Constraint;
|
||||||
|
|
||||||
|
using ConstraintSet = std::vector<Constraint*>;
|
||||||
|
|
||||||
class Forall {
|
class Forall {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
TVSet TVs;
|
TVSet* TVs;
|
||||||
std::vector<Constraint*> Constraints;
|
ConstraintSet* Constraints;
|
||||||
Type* Type;
|
Type* Type;
|
||||||
|
|
||||||
inline Forall(class Type* Type):
|
inline Forall(class Type* Type):
|
||||||
Type(Type) {}
|
TVs(nullptr), Constraints(nullptr), Type(Type) {}
|
||||||
|
|
||||||
inline Forall(
|
inline Forall(
|
||||||
TVSet TVs,
|
TVSet& TVs,
|
||||||
std::vector<Constraint*> Constraints,
|
ConstraintSet& Constraints,
|
||||||
class Type* Type
|
class Type* Type
|
||||||
): TVs(TVs),
|
): TVs(&TVs),
|
||||||
Constraints(Constraints),
|
Constraints(&Constraints),
|
||||||
Type(Type) {}
|
Type(Type) {}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -184,19 +186,7 @@ namespace bolt {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TypeEnv {
|
using TypeEnv = std::unordered_map<ByteString, Scheme>;
|
||||||
|
|
||||||
std::unordered_map<ByteString, Scheme> Mapping;
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
void add(ByteString Name, Scheme S);
|
|
||||||
|
|
||||||
Scheme* lookup(ByteString Name);
|
|
||||||
|
|
||||||
Type* lookupMono(ByteString Name);
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class ConstraintKind {
|
enum class ConstraintKind {
|
||||||
Equal,
|
Equal,
|
||||||
|
@ -217,12 +207,12 @@ namespace bolt {
|
||||||
return Kind;
|
return Kind;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Constraint* substitute(const TVSub& Sub);
|
||||||
|
|
||||||
virtual ~Constraint() {}
|
virtual ~Constraint() {}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
using ConstraintSet = std::vector<Constraint*>;
|
|
||||||
|
|
||||||
class CEqual : public Constraint {
|
class CEqual : public Constraint {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -238,9 +228,9 @@ namespace bolt {
|
||||||
class CMany : public Constraint {
|
class CMany : public Constraint {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
ConstraintSet Constraints;
|
ConstraintSet& Constraints;
|
||||||
|
|
||||||
inline CMany(ConstraintSet Constraints):
|
inline CMany(ConstraintSet& Constraints):
|
||||||
Constraint(ConstraintKind::Many), Constraints(Constraints) {}
|
Constraint(ConstraintKind::Many), Constraints(Constraints) {}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -254,18 +244,28 @@ namespace bolt {
|
||||||
};
|
};
|
||||||
|
|
||||||
class InferContext {
|
class InferContext {
|
||||||
|
|
||||||
ConstraintSet& Constraints;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
TypeEnv& Env;
|
TVSet TVs;
|
||||||
|
ConstraintSet Constraints;
|
||||||
|
TypeEnv Env;
|
||||||
|
|
||||||
inline InferContext(ConstraintSet& Constraints, TypeEnv& Env):
|
InferContext* Parent;
|
||||||
Constraints(Constraints), Env(Env) {}
|
|
||||||
|
inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env):
|
||||||
|
Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env) {}
|
||||||
|
|
||||||
|
inline InferContext(InferContext* Parent = nullptr):
|
||||||
|
Parent(Parent) {}
|
||||||
|
|
||||||
void addConstraint(Constraint* C);
|
void addConstraint(Constraint* C);
|
||||||
|
|
||||||
|
void addBinding(ByteString Name, Scheme Scm);
|
||||||
|
|
||||||
|
Type* lookupMono(ByteString Name);
|
||||||
|
|
||||||
|
Scheme* lookup(ByteString Name);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class Checker {
|
class Checker {
|
||||||
|
@ -275,14 +275,18 @@ namespace bolt {
|
||||||
size_t nextConTypeId = 0;
|
size_t nextConTypeId = 0;
|
||||||
size_t nextTypeVarId = 0;
|
size_t nextTypeVarId = 0;
|
||||||
|
|
||||||
Type* inferExpression(Expression* Expression, InferContext& Env);
|
Type* inferExpression(Expression* Expression, InferContext& Ctx);
|
||||||
|
Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx);
|
||||||
|
|
||||||
void infer(Node* node, InferContext& Env);
|
void inferBindings(Pattern* Pattern, Type* T, InferContext& Ctx, ConstraintSet& Constraints, TVSet& Tvs);
|
||||||
|
|
||||||
|
void infer(Node* node, InferContext& Ctx);
|
||||||
|
|
||||||
TCon* createPrimConType();
|
TCon* createPrimConType();
|
||||||
TVar* createTypeVar();
|
|
||||||
|
|
||||||
Type* instantiate(Scheme& S);
|
TVar* createTypeVar(InferContext& Ctx);
|
||||||
|
|
||||||
|
Type* instantiate(Scheme& S, InferContext& Ctx, Node* Source);
|
||||||
|
|
||||||
bool unify(Type* A, Type* B, TVSub& Solution);
|
bool unify(Type* A, Type* B, TVSub& Solution);
|
||||||
|
|
||||||
|
|
|
@ -70,6 +70,10 @@ namespace bolt {
|
||||||
|
|
||||||
Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence);
|
Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence);
|
||||||
|
|
||||||
|
TypeExpression* parsePrimitiveTypeExpression();
|
||||||
|
|
||||||
|
Expression* parsePrimitiveExpression();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
Parser(TextFile& File, Stream<Token*>& S);
|
Parser(TextFile& File, Stream<Token*>& S);
|
||||||
|
@ -86,8 +90,6 @@ namespace bolt {
|
||||||
|
|
||||||
Expression* parseUnaryExpression();
|
Expression* parseUnaryExpression();
|
||||||
|
|
||||||
Expression* parsePrimitiveExpression();
|
|
||||||
|
|
||||||
Expression* parseExpression();
|
Expression* parseExpression();
|
||||||
|
|
||||||
Expression* parseCallExpression();
|
Expression* parseCallExpression();
|
||||||
|
|
34
src/CST.cc
34
src/CST.cc
|
@ -43,6 +43,15 @@ namespace bolt {
|
||||||
Name->Parent = this;
|
Name->Parent = this;
|
||||||
Name->setParents();
|
Name->setParents();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ArrowTypeExpression::setParents() {
|
||||||
|
for (auto ParamType: ParamTypes) {
|
||||||
|
ParamType->Parent = this;
|
||||||
|
ParamType->setParents();
|
||||||
|
}
|
||||||
|
ReturnType->Parent = this;
|
||||||
|
ReturnType->setParents();
|
||||||
|
}
|
||||||
|
|
||||||
void BindPattern::setParents() {
|
void BindPattern::setParents() {
|
||||||
Name->Parent = this;
|
Name->Parent = this;
|
||||||
|
@ -179,6 +188,9 @@ namespace bolt {
|
||||||
Colon::~Colon() {
|
Colon::~Colon() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RArrow::~RArrow() {
|
||||||
|
}
|
||||||
|
|
||||||
Dot::~Dot() {
|
Dot::~Dot() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -268,6 +280,13 @@ namespace bolt {
|
||||||
Name->unref();
|
Name->unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ArrowTypeExpression::~ArrowTypeExpression() {
|
||||||
|
for (auto ParamType: ParamTypes) {
|
||||||
|
ParamType->unref();
|
||||||
|
}
|
||||||
|
ReturnType->unref();
|
||||||
|
}
|
||||||
|
|
||||||
Pattern::~Pattern() {
|
Pattern::~Pattern() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,6 +420,17 @@ namespace bolt {
|
||||||
return Name->getFirstToken();
|
return Name->getFirstToken();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Token* ArrowTypeExpression::getFirstToken() {
|
||||||
|
if (ParamTypes.size()) {
|
||||||
|
return ParamTypes.front()->getFirstToken();
|
||||||
|
}
|
||||||
|
return ReturnType->getFirstToken();
|
||||||
|
}
|
||||||
|
|
||||||
|
Token* ArrowTypeExpression::getLastToken() {
|
||||||
|
return ReturnType->getLastToken();
|
||||||
|
}
|
||||||
|
|
||||||
Token* BindPattern::getFirstToken() {
|
Token* BindPattern::getFirstToken() {
|
||||||
return Name;
|
return Name;
|
||||||
}
|
}
|
||||||
|
@ -573,6 +603,10 @@ namespace bolt {
|
||||||
return ":";
|
return ":";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string RArrow::getText() const {
|
||||||
|
return "->";
|
||||||
|
}
|
||||||
|
|
||||||
std::string Dot::getText() const {
|
std::string Dot::getText() const {
|
||||||
return ".";
|
return ".";
|
||||||
}
|
}
|
||||||
|
|
265
src/Checker.cc
265
src/Checker.cc
|
@ -9,27 +9,7 @@
|
||||||
|
|
||||||
namespace bolt {
|
namespace bolt {
|
||||||
|
|
||||||
Scheme* TypeEnv::lookup(ByteString Name) {
|
std::string describe(const Type* Ty);
|
||||||
auto Match = Mapping.find(Name);
|
|
||||||
if (Match == Mapping.end()) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return &Match->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
Type* TypeEnv::lookupMono(ByteString Name) {
|
|
||||||
auto Match = Mapping.find(Name);
|
|
||||||
if (Match == Mapping.end()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto& F = Match->second.as<Forall>();
|
|
||||||
ZEN_ASSERT(F.TVs.empty());
|
|
||||||
return F.Type;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TypeEnv::add(ByteString Name, Scheme S) {
|
|
||||||
Mapping.emplace(Name, S);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Type::hasTypeVar(const TVar* TV) {
|
bool Type::hasTypeVar(const TVar* TV) {
|
||||||
switch (Kind) {
|
switch (Kind) {
|
||||||
|
@ -66,7 +46,7 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Y = static_cast<TVar*>(this);
|
auto Y = static_cast<TVar*>(this);
|
||||||
auto Match = Sub.find(Y);
|
auto Match = Sub.find(Y);
|
||||||
return Match != Sub.end() ? Match->second : Y;
|
return Match != Sub.end() ? Match->second->substitute(Sub) : Y;
|
||||||
}
|
}
|
||||||
case TypeKind::Arrow:
|
case TypeKind::Arrow:
|
||||||
{
|
{
|
||||||
|
@ -87,11 +67,60 @@ namespace bolt {
|
||||||
for (auto Arg: Y->Args) {
|
for (auto Arg: Y->Args) {
|
||||||
NewArgs.push_back(Arg->substitute(Sub));
|
NewArgs.push_back(Arg->substitute(Sub));
|
||||||
}
|
}
|
||||||
return new TCon(Y->Id, Y->Args, Y->DisplayName);
|
return new TCon(Y->Id, NewArgs, Y->DisplayName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Constraint* Constraint::substitute(const TVSub &Sub) {
|
||||||
|
switch (Kind) {
|
||||||
|
case ConstraintKind::Equal:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<CEqual*>(this);
|
||||||
|
return new CEqual(Y->Left->substitute(Sub), Y->Right->substitute(Sub), Y->Source);
|
||||||
|
}
|
||||||
|
case ConstraintKind::Many:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<CMany*>(this);
|
||||||
|
auto NewConstraints = new ConstraintSet();
|
||||||
|
for (auto Element: Y->Constraints) {
|
||||||
|
NewConstraints->push_back(Element->substitute(Sub));
|
||||||
|
}
|
||||||
|
return new CMany(*NewConstraints);
|
||||||
|
}
|
||||||
|
case ConstraintKind::Empty:
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Scheme* InferContext::lookup(ByteString Name) {
|
||||||
|
InferContext* Curr = this;
|
||||||
|
for (;;) {
|
||||||
|
auto Match = Curr->Env.find(Name);
|
||||||
|
if (Match != Curr->Env.end()) {
|
||||||
|
return &Match->second;
|
||||||
|
}
|
||||||
|
Curr = Curr->Parent;
|
||||||
|
if (Curr == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Type* InferContext::lookupMono(ByteString Name) {
|
||||||
|
auto Scm = lookup(Name);
|
||||||
|
if (Scm == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto& F = Scm->as<Forall>();
|
||||||
|
ZEN_ASSERT(F.TVs == nullptr || F.TVs->empty());
|
||||||
|
return F.Type;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InferContext::addBinding(ByteString Name, Scheme S) {
|
||||||
|
Env.emplace(Name, S);
|
||||||
|
}
|
||||||
|
|
||||||
void InferContext::addConstraint(Constraint *C) {
|
void InferContext::addConstraint(Constraint *C) {
|
||||||
Constraints.push_back(C);
|
Constraints.push_back(C);
|
||||||
}
|
}
|
||||||
|
@ -114,7 +143,57 @@ namespace bolt {
|
||||||
|
|
||||||
case NodeType::LetDeclaration:
|
case NodeType::LetDeclaration:
|
||||||
{
|
{
|
||||||
// TODO
|
auto Y = static_cast<LetDeclaration*>(X);
|
||||||
|
|
||||||
|
auto NewCtx = new InferContext { Ctx };
|
||||||
|
|
||||||
|
Type* Ty;
|
||||||
|
if (Y->TypeAssert) {
|
||||||
|
Ty = inferTypeExpression(Y->TypeAssert->TypeExpression, *NewCtx);
|
||||||
|
} else {
|
||||||
|
Ty = createTypeVar(*NewCtx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Type*> ParamTypes;
|
||||||
|
Type* RetType;
|
||||||
|
|
||||||
|
for (auto Param: Y->Params) {
|
||||||
|
// TODO incorporate Param->TypeAssert or make it a kind of pattern
|
||||||
|
TVar* TV = createTypeVar(*NewCtx);
|
||||||
|
TVSet NoTVs;
|
||||||
|
ConstraintSet NoConstraints;
|
||||||
|
inferBindings(Param->Pattern, TV, *NewCtx, NoConstraints, NoTVs);
|
||||||
|
ParamTypes.push_back(TV);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Y->Body) {
|
||||||
|
switch (Y->Body->Type) {
|
||||||
|
case NodeType::LetExprBody:
|
||||||
|
{
|
||||||
|
auto Z = static_cast<LetExprBody*>(Y->Body);
|
||||||
|
RetType = inferExpression(Z->Expression, *NewCtx);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case NodeType::LetBlockBody:
|
||||||
|
{
|
||||||
|
auto Z = static_cast<LetBlockBody*>(Y->Body);
|
||||||
|
RetType = createTypeVar(*NewCtx);
|
||||||
|
for (auto Element: Z->Elements) {
|
||||||
|
infer(Element, *NewCtx);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ZEN_UNREACHABLE
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RetType = createTypeVar(*NewCtx);
|
||||||
|
}
|
||||||
|
|
||||||
|
NewCtx->addConstraint(new CEqual { Ty, new TArrow(ParamTypes, RetType), X });
|
||||||
|
|
||||||
|
inferBindings(Y->Pattern, Ty, Ctx, NewCtx->Constraints, NewCtx->TVs);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -133,21 +212,43 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TVar* Checker::createTypeVar() {
|
TVar* Checker::createTypeVar(InferContext& Ctx) {
|
||||||
return new TVar(nextTypeVarId++);
|
auto TV = new TVar(nextTypeVarId++);
|
||||||
|
Ctx.TVs.emplace(TV);
|
||||||
|
return TV;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type* Checker::instantiate(Scheme& S) {
|
Type* Checker::instantiate(Scheme& S, InferContext& Ctx, Node* Source) {
|
||||||
|
|
||||||
switch (S.getKind()) {
|
switch (S.getKind()) {
|
||||||
|
|
||||||
case SchemeKind::Forall:
|
case SchemeKind::Forall:
|
||||||
{
|
{
|
||||||
auto& F = S.as<Forall>();
|
auto& F = S.as<Forall>();
|
||||||
|
|
||||||
TVSub Sub;
|
TVSub Sub;
|
||||||
for (auto TV: F.TVs) {
|
if (F.TVs) {
|
||||||
Sub[TV] = createTypeVar();
|
for (auto TV: *F.TVs) {
|
||||||
|
Sub[TV] = createTypeVar(Ctx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (F.Constraints) {
|
||||||
|
|
||||||
|
for (auto Constraint: *F.Constraints) {
|
||||||
|
|
||||||
|
auto NewConstraint = Constraint->substitute(Sub);
|
||||||
|
|
||||||
|
// This makes error messages prettier by relating the typing failure
|
||||||
|
// to the call site rather than the definition.
|
||||||
|
if (NewConstraint->getKind() == ConstraintKind::Equal) {
|
||||||
|
static_cast<CEqual *>(NewConstraint)->Source = Source;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctx.addConstraint(NewConstraint);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return F.Type->substitute(Sub);
|
return F.Type->substitute(Sub);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,6 +256,37 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Type* Checker::inferTypeExpression(TypeExpression* X, InferContext& Ctx) {
|
||||||
|
|
||||||
|
switch (X->Type) {
|
||||||
|
|
||||||
|
case NodeType::ReferenceTypeExpression:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<ReferenceTypeExpression*>(X);
|
||||||
|
auto Ty = Ctx.lookupMono(Y->Name->Name->Text);
|
||||||
|
if (Ty == nullptr) {
|
||||||
|
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name->Name);
|
||||||
|
return new TAny();
|
||||||
|
}
|
||||||
|
return Ty;
|
||||||
|
}
|
||||||
|
|
||||||
|
case NodeType::ArrowTypeExpression:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<ArrowTypeExpression*>(X);
|
||||||
|
std::vector<Type*> ParamTypes;
|
||||||
|
for (auto ParamType: Y->ParamTypes) {
|
||||||
|
ParamTypes.push_back(inferTypeExpression(ParamType, Ctx));
|
||||||
|
}
|
||||||
|
auto ReturnType = inferTypeExpression(Y->ReturnType, Ctx);
|
||||||
|
return new TArrow(ParamTypes, ReturnType);
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
ZEN_UNREACHABLE
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Type* Checker::inferExpression(Expression* X, InferContext& Ctx) {
|
Type* Checker::inferExpression(Expression* X, InferContext& Ctx) {
|
||||||
|
|
||||||
|
@ -166,10 +298,10 @@ namespace bolt {
|
||||||
Type* Ty = nullptr;
|
Type* Ty = nullptr;
|
||||||
switch (Y->Token->Type) {
|
switch (Y->Token->Type) {
|
||||||
case NodeType::IntegerLiteral:
|
case NodeType::IntegerLiteral:
|
||||||
Ty = Ctx.Env.lookupMono("Int");
|
Ty = Ctx.lookupMono("Int");
|
||||||
break;
|
break;
|
||||||
case NodeType::StringLiteral:
|
case NodeType::StringLiteral:
|
||||||
Ty = Ctx.Env.lookupMono("String");
|
Ty = Ctx.lookupMono("String");
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
ZEN_UNREACHABLE
|
ZEN_UNREACHABLE
|
||||||
|
@ -182,24 +314,37 @@ namespace bolt {
|
||||||
{
|
{
|
||||||
auto Y = static_cast<ReferenceExpression*>(X);
|
auto Y = static_cast<ReferenceExpression*>(X);
|
||||||
ZEN_ASSERT(Y->Name->ModulePath.empty());
|
ZEN_ASSERT(Y->Name->ModulePath.empty());
|
||||||
auto Scm = Ctx.Env.lookup(Y->Name->Name->Text);
|
auto Scm = Ctx.lookup(Y->Name->Name->Text);
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
|
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name);
|
||||||
return new TAny();
|
return new TAny();
|
||||||
}
|
}
|
||||||
return instantiate(*Scm);
|
return instantiate(*Scm, Ctx, X);
|
||||||
|
}
|
||||||
|
|
||||||
|
case NodeType::CallExpression:
|
||||||
|
{
|
||||||
|
auto Y = static_cast<CallExpression*>(X);
|
||||||
|
auto OpTy = inferExpression(Y->Function, Ctx);
|
||||||
|
auto RetType = createTypeVar(Ctx);
|
||||||
|
std::vector<Type*> ArgTypes;
|
||||||
|
for (auto Arg: Y->Args) {
|
||||||
|
ArgTypes.push_back(inferExpression(Arg, Ctx));
|
||||||
|
}
|
||||||
|
Ctx.addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X });
|
||||||
|
return RetType;
|
||||||
}
|
}
|
||||||
|
|
||||||
case NodeType::InfixExpression:
|
case NodeType::InfixExpression:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<InfixExpression*>(X);
|
auto Y = static_cast<InfixExpression*>(X);
|
||||||
auto Scm = Ctx.Env.lookup(Y->Operator->getText());
|
auto Scm = Ctx.lookup(Y->Operator->getText());
|
||||||
if (Scm == nullptr) {
|
if (Scm == nullptr) {
|
||||||
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
|
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
|
||||||
return new TAny();
|
return new TAny();
|
||||||
}
|
}
|
||||||
auto OpTy = instantiate(*Scm);
|
auto OpTy = instantiate(*Scm, Ctx, Y->Operator);
|
||||||
auto RetTy = createTypeVar();
|
auto RetTy = createTypeVar(Ctx);
|
||||||
std::vector<Type*> ArgTys;
|
std::vector<Type*> ArgTys;
|
||||||
ArgTys.push_back(inferExpression(Y->LHS, Ctx));
|
ArgTys.push_back(inferExpression(Y->LHS, Ctx));
|
||||||
ArgTys.push_back(inferExpression(Y->RHS, Ctx));
|
ArgTys.push_back(inferExpression(Y->RHS, Ctx));
|
||||||
|
@ -214,24 +359,41 @@ namespace bolt {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Checker::inferBindings(Pattern* Pattern, Type* Type, InferContext& Ctx, ConstraintSet& Constraints, TVSet& TVs) {
|
||||||
|
|
||||||
|
switch (Pattern->Type) {
|
||||||
|
|
||||||
|
case NodeType::BindPattern:
|
||||||
|
Ctx.addBinding(static_cast<BindPattern*>(Pattern)->Name->Text, Forall(TVs, Constraints, Type));
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
ZEN_UNREACHABLE
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Checker::check(SourceFile *SF) {
|
void Checker::check(SourceFile *SF) {
|
||||||
TypeEnv Global;
|
InferContext Toplevel;
|
||||||
auto StringTy = new TCon(nextConTypeId++, {}, "String");
|
auto StringTy = new TCon(nextConTypeId++, {}, "String");
|
||||||
Global.add("String", Forall(StringTy));
|
|
||||||
auto IntTy = new TCon(nextConTypeId++, {}, "Int");
|
auto IntTy = new TCon(nextConTypeId++, {}, "Int");
|
||||||
Global.add("Int", Forall(IntTy));
|
auto BoolTy = new TCon(nextConTypeId++, {}, "Bool");
|
||||||
Global.add("+", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
|
Toplevel.addBinding("String", Forall(StringTy));
|
||||||
ConstraintSet Constraints;
|
Toplevel.addBinding("Int", Forall(IntTy));
|
||||||
InferContext Toplevel { Constraints, Global };
|
Toplevel.addBinding("Bool", Forall(BoolTy));
|
||||||
|
Toplevel.addBinding("+", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
|
||||||
|
Toplevel.addBinding("-", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
|
||||||
|
Toplevel.addBinding("*", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
|
||||||
|
Toplevel.addBinding("/", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
|
||||||
infer(SF, Toplevel);
|
infer(SF, Toplevel);
|
||||||
solve(new CMany(Constraints));
|
solve(new CMany(Toplevel.Constraints));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Checker::solve(Constraint* Constraint) {
|
void Checker::solve(Constraint* Constraint) {
|
||||||
|
|
||||||
std::stack<class Constraint*> Queue;
|
std::stack<class Constraint*> Queue;
|
||||||
Queue.push(Constraint);
|
Queue.push(Constraint);
|
||||||
TVSub Sub;
|
TVSub Solution;
|
||||||
|
|
||||||
while (!Queue.empty()) {
|
while (!Queue.empty()) {
|
||||||
|
|
||||||
|
@ -256,8 +418,9 @@ namespace bolt {
|
||||||
case ConstraintKind::Equal:
|
case ConstraintKind::Equal:
|
||||||
{
|
{
|
||||||
auto Y = static_cast<CEqual*>(Constraint);
|
auto Y = static_cast<CEqual*>(Constraint);
|
||||||
if (!unify(Y->Left, Y->Right, Sub)) {
|
std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl;
|
||||||
DE.add<UnificationErrorDiagnostic>(Y->Left, Y->Right, Y->Source);
|
if (!unify(Y->Left, Y->Right, Solution)) {
|
||||||
|
DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -288,6 +451,7 @@ namespace bolt {
|
||||||
auto Y = static_cast<TVar*>(A);
|
auto Y = static_cast<TVar*>(A);
|
||||||
if (B->hasTypeVar(Y)) {
|
if (B->hasTypeVar(Y)) {
|
||||||
// TODO occurs check
|
// TODO occurs check
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
Solution[Y] = B;
|
Solution[Y] = B;
|
||||||
return true;
|
return true;
|
||||||
|
@ -297,11 +461,14 @@ namespace bolt {
|
||||||
return unify(B, A, Solution);
|
return unify(B, A, Solution);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (A->getKind() == TypeKind::Any || B->getKind() == TypeKind::Any) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) {
|
if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) {
|
||||||
auto Y = static_cast<TArrow*>(A);
|
auto Y = static_cast<TArrow*>(A);
|
||||||
auto Z = static_cast<TArrow*>(B);
|
auto Z = static_cast<TArrow*>(B);
|
||||||
if (Y->ParamTypes.size() != Z->ParamTypes.size()) {
|
if (Y->ParamTypes.size() != Z->ParamTypes.size()) {
|
||||||
// TODO diagnostic
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto Count = Y->ParamTypes.size();
|
auto Count = Y->ParamTypes.size();
|
||||||
|
@ -313,11 +480,10 @@ namespace bolt {
|
||||||
return unify(Y->ReturnType, Z->ReturnType, Solution);
|
return unify(Y->ReturnType, Z->ReturnType, Solution);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Arrow) {
|
if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) {
|
||||||
auto Y = static_cast<TCon*>(A);
|
auto Y = static_cast<TCon*>(A);
|
||||||
auto Z = static_cast<TCon*>(B);
|
auto Z = static_cast<TCon*>(B);
|
||||||
if (Y->Id != Z->Id) {
|
if (Y->Id != Z->Id) {
|
||||||
// TODO diagnostic
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ZEN_ASSERT(Y->Args.size() == Z->Args.size());
|
ZEN_ASSERT(Y->Args.size() == Z->Args.size());
|
||||||
|
@ -330,7 +496,6 @@ namespace bolt {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO diagnostic
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -95,7 +95,7 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string describe(const Type* Ty) {
|
std::string describe(const Type* Ty) {
|
||||||
switch (Ty->getKind()) {
|
switch (Ty->getKind()) {
|
||||||
case TypeKind::Any:
|
case TypeKind::Any:
|
||||||
return "any";
|
return "any";
|
||||||
|
@ -320,10 +320,13 @@ namespace bolt {
|
||||||
case DiagnosticKind::BindingNotFound:
|
case DiagnosticKind::BindingNotFound:
|
||||||
{
|
{
|
||||||
auto E = static_cast<const BindingNotFoundDiagnostic&>(D);
|
auto E = static_cast<const BindingNotFoundDiagnostic&>(D);
|
||||||
Out << ANSI_BOLD ANSI_FG_RED "error: " ANSI_RESET "binding '" << E.Name << "' was not found\n";
|
Out << ANSI_BOLD ANSI_FG_RED "error: " ANSI_RESET "binding '" << E.Name << "' was not found\n\n";
|
||||||
//if (E.Initiator != nullptr) {
|
if (E.Initiator != nullptr) {
|
||||||
// writeExcerpt(E.Initiator->getRange());
|
auto Range = E.Initiator->getRange();
|
||||||
//}
|
//std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n";
|
||||||
|
writeExcerpt(E.Initiator->getSourceFile()->getTextFile(), Range, Range, Color::Red);
|
||||||
|
Out << "\n";
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ namespace bolt {
|
||||||
return new QualifiedName(ModulePath, static_cast<Identifier*>(Name));
|
return new QualifiedName(ModulePath, static_cast<Identifier*>(Name));
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeExpression* Parser::parseTypeExpression() {
|
TypeExpression* Parser::parsePrimitiveTypeExpression() {
|
||||||
auto T0 = Tokens.peek();
|
auto T0 = Tokens.peek();
|
||||||
switch (T0->Type) {
|
switch (T0->Type) {
|
||||||
case NodeType::Identifier:
|
case NodeType::Identifier:
|
||||||
|
@ -116,6 +116,24 @@ namespace bolt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypeExpression* Parser::parseTypeExpression() {
|
||||||
|
auto RetType = parsePrimitiveTypeExpression();
|
||||||
|
std::vector<TypeExpression*> ParamTypes;
|
||||||
|
for (;;) {
|
||||||
|
auto T1 = Tokens.peek();
|
||||||
|
if (T1->Type != NodeType::RArrow) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Tokens.get();
|
||||||
|
ParamTypes.push_back(RetType);
|
||||||
|
RetType = parsePrimitiveTypeExpression();
|
||||||
|
}
|
||||||
|
if (ParamTypes.size()) {
|
||||||
|
return new ArrowTypeExpression(ParamTypes, RetType);
|
||||||
|
}
|
||||||
|
return RetType;
|
||||||
|
}
|
||||||
|
|
||||||
Expression* Parser::parsePrimitiveExpression() {
|
Expression* Parser::parsePrimitiveExpression() {
|
||||||
auto T0 = Tokens.peek();
|
auto T0 = Tokens.peek();
|
||||||
switch (T0->Type) {
|
switch (T0->Type) {
|
||||||
|
|
|
@ -294,7 +294,9 @@ after_string_contents:
|
||||||
Text.push_back(static_cast<char>(C1));
|
Text.push_back(static_cast<char>(C1));
|
||||||
getChar();
|
getChar();
|
||||||
}
|
}
|
||||||
if (Text == "=") {
|
if (Text == "->") {
|
||||||
|
return new RArrow(StartLoc);
|
||||||
|
} else if (Text == "=") {
|
||||||
return new Equals(StartLoc);
|
return new Equals(StartLoc);
|
||||||
} else if (Text.back() == '=' && Text[Text.size()-2] != '=') {
|
} else if (Text.back() == '=' && Text[Text.size()-2] != '=') {
|
||||||
return new Assignment(Text.substr(0, Text.size()-1), StartLoc);
|
return new Assignment(Text.substr(0, Text.size()-1), StartLoc);
|
||||||
|
|
Loading…
Reference in a new issue