Implement tuples and fix bug with type vars in infer/unify algorithm
This commit is contained in:
parent
fd015dcf22
commit
508ef40bdf
8 changed files with 316 additions and 26 deletions
|
@ -76,6 +76,7 @@ namespace bolt {
|
|||
MatchCase,
|
||||
MatchExpression,
|
||||
MemberExpression,
|
||||
TupleExpression,
|
||||
NestedExpression,
|
||||
ConstantExpression,
|
||||
CallExpression,
|
||||
|
@ -830,6 +831,10 @@ namespace bolt {
|
|||
|
||||
std::string getText() const override;
|
||||
|
||||
inline Integer getInteger() const noexcept {
|
||||
return V;
|
||||
}
|
||||
|
||||
Value getValue() override;
|
||||
|
||||
static bool classof(const Node* N) {
|
||||
|
@ -1143,6 +1148,27 @@ namespace bolt {
|
|||
|
||||
};
|
||||
|
||||
class TupleExpression : public Expression {
|
||||
public:
|
||||
|
||||
class LParen* LParen;
|
||||
std::vector<std::tuple<Expression*, Comma*>> Elements;
|
||||
class RParen* RParen;
|
||||
|
||||
inline TupleExpression(
|
||||
class LParen* LParen,
|
||||
std::vector<std::tuple<Expression*, Comma*>> Elements,
|
||||
class RParen* RParen
|
||||
): Expression(NodeKind::TupleExpression),
|
||||
LParen(LParen),
|
||||
Elements(Elements),
|
||||
RParen(RParen) {}
|
||||
|
||||
Token* getFirstToken() override;
|
||||
Token* getLastToken() override;
|
||||
|
||||
};
|
||||
|
||||
class NestedExpression : public Expression {
|
||||
public:
|
||||
|
||||
|
|
|
@ -111,6 +111,8 @@ namespace bolt {
|
|||
return static_cast<D*>(this)->visitMatchExpression(static_cast<MatchExpression*>(N));
|
||||
case NodeKind::MemberExpression:
|
||||
return static_cast<D*>(this)->visitMemberExpression(static_cast<MemberExpression*>(N));
|
||||
case NodeKind::TupleExpression:
|
||||
return static_cast<D*>(this)->visitTupleExpression(static_cast<TupleExpression*>(N));
|
||||
case NodeKind::NestedExpression:
|
||||
return static_cast<D*>(this)->visitNestedExpression(static_cast<NestedExpression*>(N));
|
||||
case NodeKind::ConstantExpression:
|
||||
|
@ -378,6 +380,10 @@ namespace bolt {
|
|||
visitExpression(N);
|
||||
}
|
||||
|
||||
void visitTupleExpression(TupleExpression* N) {
|
||||
visitExpression(N);
|
||||
}
|
||||
|
||||
void visitNestedExpression(NestedExpression* N) {
|
||||
visitExpression(N);
|
||||
}
|
||||
|
@ -616,6 +622,9 @@ namespace bolt {
|
|||
case NodeKind::MemberExpression:
|
||||
visitEachChild(static_cast<MemberExpression*>(N));
|
||||
break;
|
||||
case NodeKind::TupleExpression:
|
||||
visitEachChild(static_cast<TupleExpression*>(N));
|
||||
break;
|
||||
case NodeKind::NestedExpression:
|
||||
visitEachChild(static_cast<NestedExpression*>(N));
|
||||
break;
|
||||
|
@ -876,6 +885,17 @@ namespace bolt {
|
|||
BOLT_VISIT(N->Name);
|
||||
}
|
||||
|
||||
void visitEachChild(TupleExpression* N) {
|
||||
BOLT_VISIT(N->LParen);
|
||||
for (auto [E, Comma]: N->Elements) {
|
||||
BOLT_VISIT(E);
|
||||
if (Comma) {
|
||||
BOLT_VISIT(Comma);
|
||||
}
|
||||
}
|
||||
BOLT_VISIT(N->RParen);
|
||||
}
|
||||
|
||||
void visitEachChild(NestedExpression* N) {
|
||||
BOLT_VISIT(N->LParen);
|
||||
BOLT_VISIT(N->Inner);
|
||||
|
|
|
@ -32,6 +32,7 @@ namespace bolt {
|
|||
Con,
|
||||
Arrow,
|
||||
Tuple,
|
||||
TupleIndex,
|
||||
};
|
||||
|
||||
class Type {
|
||||
|
@ -148,6 +149,21 @@ namespace bolt {
|
|||
|
||||
};
|
||||
|
||||
class TTupleIndex : public Type {
|
||||
public:
|
||||
|
||||
Type* Ty;
|
||||
std::size_t I;
|
||||
|
||||
inline TTupleIndex(Type* Ty, std::size_t I):
|
||||
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::TupleIndex;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// template<typename T>
|
||||
// struct DerefHash {
|
||||
// std::size_t operator()(const T& Value) const noexcept {
|
||||
|
@ -403,6 +419,8 @@ namespace bolt {
|
|||
|
||||
void checkTypeclassSigs(Node* N);
|
||||
|
||||
Type* simplify(Type* Ty);
|
||||
void join(TVar* A, Type* B, Node* Source);
|
||||
bool unify(Type* A, Type* B, Node* Source);
|
||||
|
||||
void solveCEqual(CEqual* C);
|
||||
|
|
|
@ -15,6 +15,7 @@ namespace bolt {
|
|||
class Type;
|
||||
class TCon;
|
||||
class TVar;
|
||||
class TTuple;
|
||||
|
||||
using TypeclassId = ByteString;
|
||||
|
||||
|
@ -37,6 +38,8 @@ namespace bolt {
|
|||
TypeclassMissing,
|
||||
InstanceNotFound,
|
||||
ClassNotFound,
|
||||
TupleIndexOutOfRange,
|
||||
InvalidTypeToTypeclass,
|
||||
};
|
||||
|
||||
class Diagnostic : std::runtime_error {
|
||||
|
@ -135,6 +138,27 @@ namespace bolt {
|
|||
|
||||
};
|
||||
|
||||
class TupleIndexOutOfRangeDiagnostic : public Diagnostic {
|
||||
public:
|
||||
|
||||
TTuple* Tuple;
|
||||
std::size_t I;
|
||||
|
||||
inline TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I):
|
||||
Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {}
|
||||
|
||||
};
|
||||
|
||||
class InvalidTypeToTypeclassDiagnostic : public Diagnostic {
|
||||
public:
|
||||
|
||||
Type* Actual;
|
||||
|
||||
inline InvalidTypeToTypeclassDiagnostic(Type* Actual):
|
||||
Diagnostic(DiagnosticKind::InvalidTypeToTypeclass) {}
|
||||
|
||||
};
|
||||
|
||||
class DiagnosticEngine {
|
||||
protected:
|
||||
|
||||
|
|
|
@ -283,6 +283,14 @@ namespace bolt {
|
|||
return Name;
|
||||
}
|
||||
|
||||
Token* TupleExpression::getFirstToken() {
|
||||
return LParen;
|
||||
}
|
||||
|
||||
Token* TupleExpression::getLastToken() {
|
||||
return RParen;
|
||||
}
|
||||
|
||||
Token* NestedExpression::getFirstToken() {
|
||||
return LParen;
|
||||
}
|
||||
|
|
207
src/Checker.cc
207
src/Checker.cc
|
@ -3,6 +3,10 @@
|
|||
|
||||
// TODO (maybe) make unficiation work like union-find in find()
|
||||
|
||||
// TODO make simplify() rewrite the types in-place such that a reference too (Bool, Int).0 becomes Bool
|
||||
|
||||
// TODO Fix TVSub to use TVar.Id instead of the pointer address
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <stack>
|
||||
|
@ -58,6 +62,12 @@ namespace bolt {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(this);
|
||||
Index->Ty->addTypeVars(TVs);
|
||||
break;
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
|
@ -93,6 +103,11 @@ namespace bolt {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(this);
|
||||
return Index->Ty->hasTypeVar(TV);
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
|
@ -146,6 +161,12 @@ namespace bolt {
|
|||
}
|
||||
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Tuple = static_cast<TTupleIndex*>(this);
|
||||
auto NewTy = Tuple->Ty->substitute(Sub);
|
||||
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
|
@ -821,6 +842,35 @@ namespace bolt {
|
|||
return RetTy;
|
||||
}
|
||||
|
||||
case NodeKind::TupleExpression:
|
||||
{
|
||||
auto Tuple = static_cast<TupleExpression*>(X);
|
||||
std::vector<Type*> Types;
|
||||
for (auto [E, Comma]: Tuple->Elements) {
|
||||
Types.push_back(inferExpression(E));
|
||||
}
|
||||
return new TTuple(Types);
|
||||
}
|
||||
|
||||
case NodeKind::MemberExpression:
|
||||
{
|
||||
auto Member = static_cast<MemberExpression*>(X);
|
||||
switch (Member->Name->getKind()) {
|
||||
case NodeKind::IntegerLiteral:
|
||||
{
|
||||
auto I = static_cast<IntegerLiteral*>(Member->Name);
|
||||
return new TTupleIndex(inferExpression(Member->E), I->getInteger());
|
||||
}
|
||||
case NodeKind::Identifier:
|
||||
{
|
||||
// TODO
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
}
|
||||
|
||||
case NodeKind::NestedExpression:
|
||||
{
|
||||
auto Nested = static_cast<NestedExpression*>(X);
|
||||
|
@ -1124,9 +1174,8 @@ namespace bolt {
|
|||
for (auto Class: Classes) {
|
||||
propagateClassTycon(Class, llvm::cast<TCon>(Ty), Source);
|
||||
}
|
||||
} else {
|
||||
ZEN_UNREACHABLE
|
||||
// DE.add<InvalidArgumentToTypeclassDiagnostic>(Ty);
|
||||
} else if (!Classes.empty()) {
|
||||
DE.add<InvalidTypeToTypeclassDiagnostic>(Ty);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1174,28 +1223,139 @@ namespace bolt {
|
|||
};
|
||||
|
||||
void Checker::solveCEqual(CEqual* C) {
|
||||
/* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */
|
||||
std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl;
|
||||
if (!unify(C->Left, C->Right, C->Source)) {
|
||||
DE.add<UnificationErrorDiagnostic>(C->Left->substitute(Solution), C->Right->substitute(Solution), C->Source);
|
||||
DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source);
|
||||
}
|
||||
}
|
||||
|
||||
Type* Checker::simplify(Type* Ty) {
|
||||
|
||||
while (Ty->getKind() == TypeKind::Var) {
|
||||
auto Match = Solution.find(static_cast<TVar*>(Ty));
|
||||
if (Match == Solution.end()) {
|
||||
break;
|
||||
}
|
||||
Ty = Match->second;
|
||||
}
|
||||
|
||||
switch (Ty->getKind()) {
|
||||
|
||||
case TypeKind::Var:
|
||||
break;
|
||||
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(Ty);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewElementTypes;
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
auto NewElementType = simplify(Ty);
|
||||
if (NewElementType != Ty) {
|
||||
Changed = true;
|
||||
}
|
||||
NewElementTypes.push_back(NewElementType);
|
||||
}
|
||||
return Changed ? new TTuple(NewElementTypes) : Ty;
|
||||
}
|
||||
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(Ty);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewParamTys;
|
||||
for (auto ParamTy: Arrow->ParamTypes) {
|
||||
auto NewParamTy = simplify(ParamTy);
|
||||
if (NewParamTy != ParamTy) {
|
||||
Changed = true;
|
||||
}
|
||||
NewParamTys.push_back(NewParamTy);
|
||||
}
|
||||
auto NewRetTy = simplify(Arrow->ReturnType);
|
||||
if (NewRetTy != Arrow->ReturnType) {
|
||||
Changed = true;
|
||||
}
|
||||
Ty = Changed ? new TArrow(NewParamTys, NewRetTy) : Arrow;
|
||||
break;
|
||||
}
|
||||
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(Ty);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewArgs;
|
||||
for (auto Arg: Con->Args) {
|
||||
auto NewArg = simplify(Arg);
|
||||
if (NewArg != Arg) {
|
||||
Changed = true;
|
||||
}
|
||||
NewArgs.push_back(NewArg);
|
||||
}
|
||||
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : Ty;
|
||||
}
|
||||
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(Ty);
|
||||
auto MaybeTuple = simplify(Index->Ty);
|
||||
if (llvm::isa<TTuple>(MaybeTuple)) {
|
||||
auto Tuple = static_cast<TTuple*>(MaybeTuple);
|
||||
if (Index->I >= Tuple->ElementTypes.size()) {
|
||||
DE.add<TupleIndexOutOfRangeDiagnostic>(Tuple, Index->I);
|
||||
} else {
|
||||
Ty = simplify(Tuple->ElementTypes[Index->I]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return Ty;
|
||||
}
|
||||
|
||||
void Checker::join(TVar* TV, Type* Ty, Node* Source) {
|
||||
|
||||
Solution[TV] = Ty;
|
||||
|
||||
propagateClasses(TV->Contexts, Ty, Source);
|
||||
|
||||
// This is a very specific adjustment that is critical to the
|
||||
// well-functioning of the infer/unify algorithm. When addConstraint() is
|
||||
// called, it may decide to solve the constraint immediately during
|
||||
// inference. If this happens, a type variable might get assigned a concrete
|
||||
// type such as Int. We therefore never want the variable to be polymorphic
|
||||
// and be instantiated with a fresh variable, as it has already been solved.
|
||||
// Should it get assigned another unification variable, that's OK too
|
||||
// because then the context of that variable is what matters and not anymore
|
||||
// the context of this one.
|
||||
if (!Contexts.empty()) {
|
||||
Contexts.back()->TVs->erase(TV);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool Checker::unify(Type* A, Type* B, Node* Source) {
|
||||
|
||||
auto find = [&](auto OrigTy) {
|
||||
auto Ty = OrigTy;
|
||||
while (Ty->getKind() == TypeKind::Var) {
|
||||
auto Match = Solution.find(static_cast<TVar*>(Ty));
|
||||
if (Match == Solution.end()) {
|
||||
break;
|
||||
}
|
||||
Ty = Match->second;
|
||||
if (llvm::isa<TVar>(Ty)) {
|
||||
auto TV = static_cast<TVar*>(Ty);
|
||||
do {
|
||||
auto Match = Solution.find(static_cast<TVar*>(Ty));
|
||||
if (Match == Solution.end()) {
|
||||
break;
|
||||
}
|
||||
Ty = Match->second;
|
||||
} while (Ty->getKind() == TypeKind::Var);
|
||||
// FIXME does this actually improove performance?
|
||||
Solution[TV] = Ty;
|
||||
}
|
||||
return Ty;
|
||||
};
|
||||
|
||||
A = find(A);
|
||||
B = find(B);
|
||||
A = simplify(A);
|
||||
B = simplify(B);
|
||||
|
||||
if (llvm::isa<TVar>(A) && llvm::isa<TVar>(B)) {
|
||||
auto Var1 = static_cast<TVar*>(A);
|
||||
|
@ -1206,19 +1366,19 @@ namespace bolt {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
TVar* Dest;
|
||||
TVar* To;
|
||||
TVar* From;
|
||||
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) {
|
||||
Dest = Var1;
|
||||
To = Var1;
|
||||
From = Var2;
|
||||
} else {
|
||||
// Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification
|
||||
// Either way, Var1 is a good candidate for being unified away
|
||||
Dest = Var2;
|
||||
To = Var2;
|
||||
From = Var1;
|
||||
}
|
||||
Solution[From] = Dest;
|
||||
propagateClasses(From->Contexts, Dest, Source);
|
||||
join(From, To, Source);
|
||||
propagateClasses(From->Contexts, To, Source);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1234,10 +1394,7 @@ namespace bolt {
|
|||
// than obsure references to an occurs check
|
||||
return false;
|
||||
}
|
||||
Solution[TV] = B;
|
||||
if (!TV->Contexts.empty()) {
|
||||
propagateClasses(TV->Contexts, B, Source);
|
||||
}
|
||||
join(TV, B, Source);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1301,6 +1458,12 @@ namespace bolt {
|
|||
return Success;
|
||||
}
|
||||
|
||||
// if (llvm::isa<TTupleIndex>(A) && llvm::isa<TTupleIndex>(B)) {
|
||||
// auto Index1 = static_cast<TTupleIndex*>(A);
|
||||
// auto Index2 = static_cast<TTupleIndex*>(B);
|
||||
// return unify(Index1->Ty, Index2->Ty, Source);
|
||||
// }
|
||||
|
||||
if (llvm::isa<TCon>(A) && llvm::isa<TCon>(B)) {
|
||||
auto Con1 = static_cast<TCon*>(A);
|
||||
auto Con2 = static_cast<TCon*>(B);
|
||||
|
|
|
@ -162,6 +162,11 @@ namespace bolt {
|
|||
Out << ")";
|
||||
return Out.str();
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Y = static_cast<const TTupleIndex*>(Ty);
|
||||
return describe(Y->Ty) + "." + std::to_string(Y->I);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -237,9 +237,35 @@ after_constraints:
|
|||
case NodeKind::LParen:
|
||||
{
|
||||
Tokens.get();
|
||||
auto E = parseExpression();
|
||||
auto T2 = static_cast<RParen*>(expectToken(NodeKind::RParen));
|
||||
return new NestedExpression(static_cast<LParen*>(T0), E, T2);
|
||||
std::vector<std::tuple<Expression*, Comma*>> Elements;
|
||||
auto LParen = static_cast<class LParen*>(T0);
|
||||
RParen* RParen;
|
||||
for (;;) {
|
||||
auto T1 = Tokens.peek();
|
||||
if (llvm::isa<class RParen>(T1)) {
|
||||
Tokens.get();
|
||||
RParen = static_cast<class RParen*>(T1);
|
||||
break;
|
||||
}
|
||||
auto E = parseExpression();
|
||||
auto T2 = Tokens.get();
|
||||
switch (T2->getKind()) {
|
||||
case NodeKind::RParen:
|
||||
RParen = static_cast<class RParen*>(T2);
|
||||
Elements.push_back({ E, nullptr });
|
||||
goto finish;
|
||||
case NodeKind::Comma:
|
||||
Elements.push_back({ E, static_cast<class Comma*>(T2) });
|
||||
break;
|
||||
default:
|
||||
throw UnexpectedTokenDiagnostic(File, T2, { NodeKind::RParen, NodeKind::Comma });
|
||||
}
|
||||
}
|
||||
finish:
|
||||
if (Elements.size() == 1 && !std::get<1>(Elements.front())) {
|
||||
return new NestedExpression(LParen, std::get<0>(Elements.front()), RParen);
|
||||
}
|
||||
return new TupleExpression { LParen, Elements, RParen };
|
||||
}
|
||||
case NodeKind::MatchKeyword:
|
||||
{
|
||||
|
@ -307,7 +333,7 @@ finish:
|
|||
std::vector<Expression*> Args;
|
||||
for (;;) {
|
||||
auto T1 = Tokens.peek();
|
||||
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || ExprOperators.isInfix(T1)) {
|
||||
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || T1->getKind() == NodeKind::Comma || ExprOperators.isInfix(T1)) {
|
||||
break;
|
||||
}
|
||||
Args.push_back(parsePrimitiveExpression());
|
||||
|
|
Loading…
Reference in a new issue