Implement tuples and fix bug with type vars in infer/unify algorithm

This commit is contained in:
Sam Vervaeck 2023-05-22 17:06:31 +02:00
parent fd015dcf22
commit 508ef40bdf
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
8 changed files with 316 additions and 26 deletions

View file

@ -76,6 +76,7 @@ namespace bolt {
MatchCase,
MatchExpression,
MemberExpression,
TupleExpression,
NestedExpression,
ConstantExpression,
CallExpression,
@ -830,6 +831,10 @@ namespace bolt {
std::string getText() const override;
inline Integer getInteger() const noexcept {
return V;
}
Value getValue() override;
static bool classof(const Node* N) {
@ -1143,6 +1148,27 @@ namespace bolt {
};
/// CST node for a parenthesised, comma-separated sequence of expressions.
///
/// Every entry of Elements couples one expression with the comma token that
/// follows it. The comma is null for an element that is immediately followed
/// by the closing parenthesis.
class TupleExpression : public Expression {
public:

  /// Opening parenthesis of the tuple.
  class LParen* LParen;

  /// The element expressions, each paired with its trailing comma (if any).
  std::vector<std::tuple<Expression*, Comma*>> Elements;

  /// Closing parenthesis of the tuple.
  class RParen* RParen;

  /// Stores the three parts as given and tags the node as
  /// NodeKind::TupleExpression.
  TupleExpression(
    class LParen* LParen,
    std::vector<std::tuple<Expression*, Comma*>> Elements,
    class RParen* RParen
  ): Expression(NodeKind::TupleExpression),
     LParen(LParen),
     Elements(Elements),
     RParen(RParen) {}

  Token* getFirstToken() override;

  Token* getLastToken() override;

};
class NestedExpression : public Expression {
public:

View file

@ -111,6 +111,8 @@ namespace bolt {
return static_cast<D*>(this)->visitMatchExpression(static_cast<MatchExpression*>(N));
case NodeKind::MemberExpression:
return static_cast<D*>(this)->visitMemberExpression(static_cast<MemberExpression*>(N));
case NodeKind::TupleExpression:
return static_cast<D*>(this)->visitTupleExpression(static_cast<TupleExpression*>(N));
case NodeKind::NestedExpression:
return static_cast<D*>(this)->visitNestedExpression(static_cast<NestedExpression*>(N));
case NodeKind::ConstantExpression:
@ -378,6 +380,10 @@ namespace bolt {
visitExpression(N);
}
// Handler for tuple expressions. It does nothing tuple-specific itself and
// simply hands the node on to visitExpression().
void visitTupleExpression(TupleExpression* N) {
visitExpression(N);
}
void visitNestedExpression(NestedExpression* N) {
visitExpression(N);
}
@ -616,6 +622,9 @@ namespace bolt {
case NodeKind::MemberExpression:
visitEachChild(static_cast<MemberExpression*>(N));
break;
case NodeKind::TupleExpression:
visitEachChild(static_cast<TupleExpression*>(N));
break;
case NodeKind::NestedExpression:
visitEachChild(static_cast<NestedExpression*>(N));
break;
@ -876,6 +885,17 @@ namespace bolt {
BOLT_VISIT(N->Name);
}
// Visits the direct children of a tuple expression in source order: the
// opening parenthesis, then every element followed by its separating comma
// when one is present, and finally the closing parenthesis.
void visitEachChild(TupleExpression* N) {
  BOLT_VISIT(N->LParen);
  for (const auto& Entry: N->Elements) {
    auto Item = std::get<0>(Entry);
    auto Separator = std::get<1>(Entry);
    BOLT_VISIT(Item);
    // An element may have no comma behind it, in which case the slot is null.
    if (Separator) {
      BOLT_VISIT(Separator);
    }
  }
  BOLT_VISIT(N->RParen);
}
void visitEachChild(NestedExpression* N) {
BOLT_VISIT(N->LParen);
BOLT_VISIT(N->Inner);

View file

@ -32,6 +32,7 @@ namespace bolt {
Con,
Arrow,
Tuple,
TupleIndex,
};
class Type {
@ -148,6 +149,21 @@ namespace bolt {
};
/// Type that stands for "element number I of the type Ty".
///
/// Ty does not have to be a tuple type at construction time; the node only
/// records the operand type and the requested position.
class TTupleIndex : public Type {
public:

  /// The type being indexed into.
  Type* Ty;

  /// Zero-based position of the requested element.
  std::size_t I;

  TTupleIndex(Type* Ty, std::size_t I):
    Type(TypeKind::TupleIndex),
    Ty(Ty),
    I(I) {}

  /// LLVM-style RTTI hook: true exactly when Ty is tagged as a tuple index.
  static bool classof(const Type* Ty) {
    return Ty->getKind() == TypeKind::TupleIndex;
  }

};
// template<typename T>
// struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept {
@ -403,6 +419,8 @@ namespace bolt {
void checkTypeclassSigs(Node* N);
Type* simplify(Type* Ty);
void join(TVar* A, Type* B, Node* Source);
bool unify(Type* A, Type* B, Node* Source);
void solveCEqual(CEqual* C);

View file

@ -15,6 +15,7 @@ namespace bolt {
class Type;
class TCon;
class TVar;
class TTuple;
using TypeclassId = ByteString;
@ -37,6 +38,8 @@ namespace bolt {
TypeclassMissing,
InstanceNotFound,
ClassNotFound,
TupleIndexOutOfRange,
InvalidTypeToTypeclass,
};
class Diagnostic : std::runtime_error {
@ -135,6 +138,27 @@ namespace bolt {
};
/// Diagnostic describing an index I that does not name an element of Tuple.
class TupleIndexOutOfRangeDiagnostic : public Diagnostic {
public:

  /// The tuple type that was indexed.
  TTuple* Tuple;

  /// The offending index.
  std::size_t I;

  TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I):
    Diagnostic(DiagnosticKind::TupleIndexOutOfRange),
    Tuple(Tuple),
    I(I) {}

};
/// Diagnostic describing a type that was handed to a type class although it
/// cannot carry one.
class InvalidTypeToTypeclassDiagnostic : public Diagnostic {
public:

  /// The type that was actually encountered.
  Type* Actual;

  /// Records the offending type. The member must be initialised here:
  /// leaving it out would hand consumers of the diagnostic an indeterminate
  /// pointer.
  inline InvalidTypeToTypeclassDiagnostic(Type* Actual):
    Diagnostic(DiagnosticKind::InvalidTypeToTypeclass), Actual(Actual) {}

};
class DiagnosticEngine {
protected:

View file

@ -283,6 +283,14 @@ namespace bolt {
return Name;
}
// A tuple expression starts at its opening parenthesis.
Token* TupleExpression::getFirstToken() {
  return this->LParen;
}
// A tuple expression ends at its closing parenthesis.
Token* TupleExpression::getLastToken() {
  return this->RParen;
}
Token* NestedExpression::getFirstToken() {
return LParen;
}

View file

@ -3,6 +3,10 @@
// TODO (maybe) make unification work like union-find in find()
// TODO make simplify() rewrite the types in-place such that a reference to (Bool, Int).0 becomes Bool
// TODO Fix TVSub to use TVar.Id instead of the pointer address
#include <algorithm>
#include <iterator>
#include <stack>
@ -58,6 +62,12 @@ namespace bolt {
}
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
Index->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
@ -93,6 +103,11 @@ namespace bolt {
}
return false;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
return Index->Ty->hasTypeVar(TV);
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
@ -146,6 +161,12 @@ namespace bolt {
}
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
}
case TypeKind::TupleIndex:
{
auto Tuple = static_cast<TTupleIndex*>(this);
auto NewTy = Tuple->Ty->substitute(Sub);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
@ -821,6 +842,35 @@ namespace bolt {
return RetTy;
}
case NodeKind::TupleExpression:
{
auto Tuple = static_cast<TupleExpression*>(X);
std::vector<Type*> Types;
for (auto [E, Comma]: Tuple->Elements) {
Types.push_back(inferExpression(E));
}
return new TTuple(Types);
}
case NodeKind::MemberExpression:
{
auto Member = static_cast<MemberExpression*>(X);
switch (Member->Name->getKind()) {
case NodeKind::IntegerLiteral:
{
auto I = static_cast<IntegerLiteral*>(Member->Name);
return new TTupleIndex(inferExpression(Member->E), I->getInteger());
}
case NodeKind::Identifier:
{
// TODO
break;
}
default:
ZEN_UNREACHABLE
}
}
case NodeKind::NestedExpression:
{
auto Nested = static_cast<NestedExpression*>(X);
@ -1124,9 +1174,8 @@ namespace bolt {
for (auto Class: Classes) {
propagateClassTycon(Class, llvm::cast<TCon>(Ty), Source);
}
} else {
ZEN_UNREACHABLE
// DE.add<InvalidArgumentToTypeclassDiagnostic>(Ty);
} else if (!Classes.empty()) {
DE.add<InvalidTypeToTypeclassDiagnostic>(Ty);
}
};
@ -1174,16 +1223,14 @@ namespace bolt {
};
// Solves one equality constraint by unifying its left and right side.
//
// On failure exactly one UnificationErrorDiagnostic is reported. Both sides
// are passed through simplify() first so the message shows the types as far
// as they have been resolved, rather than the raw constraint operands.
void Checker::solveCEqual(CEqual* C) {
  // Debug trace, intentionally disabled: enable locally to watch constraints
  // being solved.
  /* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */
  if (!unify(C->Left, C->Right, C->Source)) {
    DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source);
  }
}
bool Checker::unify(Type* A, Type* B, Node* Source) {
Type* Checker::simplify(Type* Ty) {
auto find = [&](auto OrigTy) {
auto Ty = OrigTy;
while (Ty->getKind() == TypeKind::Var) {
auto Match = Solution.find(static_cast<TVar*>(Ty));
if (Match == Solution.end()) {
@ -1191,11 +1238,124 @@ namespace bolt {
}
Ty = Match->second;
}
switch (Ty->getKind()) {
case TypeKind::Var:
break;
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(Ty);
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = simplify(Ty);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : Ty;
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(Ty);
bool Changed = false;
std::vector<Type*> NewParamTys;
for (auto ParamTy: Arrow->ParamTypes) {
auto NewParamTy = simplify(ParamTy);
if (NewParamTy != ParamTy) {
Changed = true;
}
NewParamTys.push_back(NewParamTy);
}
auto NewRetTy = simplify(Arrow->ReturnType);
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
Ty = Changed ? new TArrow(NewParamTys, NewRetTy) : Arrow;
break;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(Ty);
bool Changed = false;
std::vector<Type*> NewArgs;
for (auto Arg: Con->Args) {
auto NewArg = simplify(Arg);
if (NewArg != Arg) {
Changed = true;
}
NewArgs.push_back(NewArg);
}
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : Ty;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(Ty);
auto MaybeTuple = simplify(Index->Ty);
if (llvm::isa<TTuple>(MaybeTuple)) {
auto Tuple = static_cast<TTuple*>(MaybeTuple);
if (Index->I >= Tuple->ElementTypes.size()) {
DE.add<TupleIndexOutOfRangeDiagnostic>(Tuple, Index->I);
} else {
Ty = simplify(Tuple->ElementTypes[Index->I]);
}
}
break;
}
}
return Ty;
}
// Binds the type variable TV to Ty in the solution, forwards TV->Contexts
// together with Ty and Source to propagateClasses(), and finally removes TV
// from the type-variable set of the innermost active context, if any.
void Checker::join(TVar* TV, Type* Ty, Node* Source) {

  Solution[TV] = Ty;
  propagateClasses(TV->Contexts, Ty, Source);

  if (Contexts.empty()) {
    return;
  }

  // This small adjustment is essential for the infer/unify algorithm to work
  // correctly. addConstraint() may choose to solve a constraint right away,
  // while inference is still running, and doing so can bind a type variable
  // to a concrete type such as Int. A variable that has already been solved
  // must never be treated as polymorphic and instantiated with a fresh
  // variable, so it is dropped from the current context here. If it was bound
  // to another unification variable instead, that is fine as well: from then
  // on only the context of that other variable is relevant.
  auto&& Innermost = Contexts.back();
  Innermost->TVs->erase(TV);

}
bool Checker::unify(Type* A, Type* B, Node* Source) {
auto find = [&](auto OrigTy) {
auto Ty = OrigTy;
if (llvm::isa<TVar>(Ty)) {
auto TV = static_cast<TVar*>(Ty);
do {
auto Match = Solution.find(static_cast<TVar*>(Ty));
if (Match == Solution.end()) {
break;
}
Ty = Match->second;
} while (Ty->getKind() == TypeKind::Var);
// FIXME does this actually improve performance?
Solution[TV] = Ty;
}
return Ty;
};
A = find(A);
B = find(B);
A = simplify(A);
B = simplify(B);
if (llvm::isa<TVar>(A) && llvm::isa<TVar>(B)) {
auto Var1 = static_cast<TVar*>(A);
@ -1206,19 +1366,19 @@ namespace bolt {
}
return true;
}
TVar* Dest;
TVar* To;
TVar* From;
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) {
Dest = Var1;
To = Var1;
From = Var2;
} else {
// Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification
// Either way, Var1 is a good candidate for being unified away
Dest = Var2;
To = Var2;
From = Var1;
}
Solution[From] = Dest;
propagateClasses(From->Contexts, Dest, Source);
join(From, To, Source);
propagateClasses(From->Contexts, To, Source);
return true;
}
@ -1234,10 +1394,7 @@ namespace bolt {
// than obscure references to an occurs check
return false;
}
Solution[TV] = B;
if (!TV->Contexts.empty()) {
propagateClasses(TV->Contexts, B, Source);
}
join(TV, B, Source);
return true;
}
@ -1301,6 +1458,12 @@ namespace bolt {
return Success;
}
// if (llvm::isa<TTupleIndex>(A) && llvm::isa<TTupleIndex>(B)) {
// auto Index1 = static_cast<TTupleIndex*>(A);
// auto Index2 = static_cast<TTupleIndex*>(B);
// return unify(Index1->Ty, Index2->Ty, Source);
// }
if (llvm::isa<TCon>(A) && llvm::isa<TCon>(B)) {
auto Con1 = static_cast<TCon*>(A);
auto Con2 = static_cast<TCon*>(B);

View file

@ -162,6 +162,11 @@ namespace bolt {
Out << ")";
return Out.str();
}
case TypeKind::TupleIndex:
{
auto Y = static_cast<const TTupleIndex*>(Ty);
return describe(Y->Ty) + "." + std::to_string(Y->I);
}
}
}

View file

@ -237,9 +237,35 @@ after_constraints:
case NodeKind::LParen:
{
Tokens.get();
std::vector<std::tuple<Expression*, Comma*>> Elements;
auto LParen = static_cast<class LParen*>(T0);
RParen* RParen;
for (;;) {
auto T1 = Tokens.peek();
if (llvm::isa<class RParen>(T1)) {
Tokens.get();
RParen = static_cast<class RParen*>(T1);
break;
}
auto E = parseExpression();
auto T2 = static_cast<RParen*>(expectToken(NodeKind::RParen));
return new NestedExpression(static_cast<LParen*>(T0), E, T2);
auto T2 = Tokens.get();
switch (T2->getKind()) {
case NodeKind::RParen:
RParen = static_cast<class RParen*>(T2);
Elements.push_back({ E, nullptr });
goto finish;
case NodeKind::Comma:
Elements.push_back({ E, static_cast<class Comma*>(T2) });
break;
default:
throw UnexpectedTokenDiagnostic(File, T2, { NodeKind::RParen, NodeKind::Comma });
}
}
finish:
if (Elements.size() == 1 && !std::get<1>(Elements.front())) {
return new NestedExpression(LParen, std::get<0>(Elements.front()), RParen);
}
return new TupleExpression { LParen, Elements, RParen };
}
case NodeKind::MatchKeyword:
{
@ -307,7 +333,7 @@ finish:
std::vector<Expression*> Args;
for (;;) {
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || ExprOperators.isInfix(T1)) {
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || T1->getKind() == NodeKind::Comma || ExprOperators.isInfix(T1)) {
break;
}
Args.push_back(parsePrimitiveExpression());