Improve support for type classes and simplify algorithm

This commit is contained in:
Sam Vervaeck 2023-06-01 16:22:37 +02:00
parent b467f9e644
commit 693301bac6
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
6 changed files with 138 additions and 210 deletions

View file

@ -1765,6 +1765,10 @@ namespace bolt {
return Parent->getKind() == NodeKind::ClassDeclaration;
}
bool isSignature() const noexcept {
return !Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;

View file

@ -109,9 +109,9 @@ namespace bolt {
public:
TypeclassSignature Sig;
FunctionDeclaration* Decl;
Node* Decl;
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, FunctionDeclaration* Decl):
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, Node* Decl):
Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {}
inline Node* getNode() const override {
@ -124,10 +124,10 @@ namespace bolt {
public:
ByteString TypeclassName;
TCon* Ty;
Type* Ty;
Node* Source;
inline InstanceNotFoundDiagnostic(ByteString TypeclassName, TCon* Ty, Node* Source):
inline InstanceNotFoundDiagnostic(ByteString TypeclassName, Type* Ty, Node* Source):
Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {}
inline Node* getNode() const override {

View file

@ -13,6 +13,7 @@ namespace bolt {
class Type;
class TVar;
class TCon;
using TypeclassId = ByteString;
@ -29,6 +30,12 @@ namespace bolt {
};
struct TypeSig {
Type* Orig;
Type* Op;
std::vector<Type*> Args;
};
enum class TypeIndexKind {
AppOpType,
AppArgType,
@ -267,6 +274,10 @@ namespace bolt {
return VK;
}
inline bool isRigid() const noexcept {
return VK == VarKind::Rigid;
}
Type* find();
void set(Type* Ty);
@ -282,6 +293,8 @@ namespace bolt {
ByteString Name;
TypeclassContext Provided;
inline TVarRigid(size_t Id, ByteString Name):
TVar(Id, VarKind::Rigid), Name(Name) {}
@ -405,44 +418,48 @@ namespace bolt {
virtual void enterType(C<Type>* Ty) {}
virtual void exitType(C<Type>* Ty) {}
virtual void visitVarType(C<TVar>* Ty) {
virtual void visitType(C<Type>* Ty) {
visitEachChild(Ty);
}
virtual void visitVarType(C<TVar>* Ty) {
visitType(Ty);
}
virtual void visitAppType(C<TApp>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitPresentType(C<TPresent>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitConType(C<TCon>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitArrowType(C<TArrow>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitTupleType(C<TTuple>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitTupleIndexType(C<TTupleIndex>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitAbsentType(C<TAbsent>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitFieldType(C<TField>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
virtual void visitNilType(C<TNil>* Ty) {
visitEachChild(Ty);
visitType(Ty);
}
public:

View file

@ -20,6 +20,8 @@
// TODO create the constraint in addConstraint, not the other way round
// TODO Find a way to create a match expression in between ( and )
#include <algorithm>
#include <iterator>
#include <stack>
@ -271,11 +273,11 @@ namespace bolt {
case NodeKind::ClassDeclaration:
{
auto Class = static_cast<ClassDeclaration*>(X);
for (auto TE: Class->TypeVars) {
auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText());
TV->Contexts.emplace(Class->Name->getCanonicalText());
TE->setType(TV);
}
// for (auto TE: Class->TypeVars) {
// auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText());
// // TV->Contexts.emplace(Class->Name->getCanonicalText());
// TE->setType(TV);
// }
for (auto Element: Class->Elements) {
forwardDeclare(Element);
}
@ -459,17 +461,26 @@ namespace bolt {
setContext(Let->Ctx);
auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) {
auto Id = Class->Name->getCanonicalText();
auto Ctx = &getContext();
std::vector<TVar*> Out;
for (auto TE: Class->TypeVars) {
auto Name = TE->Name->getCanonicalText();
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
TV->Contexts.emplace(Id);
Ctx->Env.emplace(Name, new Forall(TV));
Out.push_back(TV);
}
return Out;
};
// If declaring a let-declaration inside a type class declaration,
// we need to mark that the let-declaration requires this class.
// This marking is set on the rigid type variables of the class, which
// are then added to this local type environment.
if (Let->isClass()) {
auto Class = static_cast<ClassDeclaration*>(Let->Parent);
for (auto TE: Class->TypeVars) {
auto TV = llvm::cast<TVar>(TE->getType());
Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
Let->Ctx->TVs->emplace(TV);
}
addClassVars(static_cast<ClassDeclaration*>(Let->Parent), true);
}
// Here we infer the primary type of the let declaration. If there's a
@ -492,27 +503,22 @@ namespace bolt {
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
auto SigLet = llvm::cast<FunctionDeclaration>(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var));
auto Params = addClassVars(Class, false);
// The type asserts in the type class declaration might make use of
// the type parameters of the type class declaration, so it is
// important to make them available in the type environment. Moreover,
// we will be unifying them with the actual types declared in the
// instance declaration, so we keep track of them.
std::vector<TVar *> Params;
TVSub Sub;
for (auto TE: Class->TypeVars) {
auto TV = createTypeVar();
Sub.emplace(llvm::cast<TVar>(TE->getType()), TV);
Params.push_back(TV);
}
auto SigLet = llvm::cast<FunctionDeclaration>(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var));
// It would be very strange if there was no type assert in the type
// class let-declaration but we rather not let the compiler crash if that happens.
if (SigLet->TypeAssert) {
addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression)->substitute(Sub), Let));
}
// std::vector<TVar *> Params;
// TVSub Sub;
// for (auto TE: Class->TypeVars) {
// auto TV = createTypeVar();
// Sub.emplace(llvm::cast<TVar>(TE->getType()), TV);
// Params.push_back(TV);
// }
// Here we do the actual unification of e.g. Eq a with Eq Bool. The
// unification variables we created previously will be unified with
@ -522,6 +528,16 @@ namespace bolt {
addConstraint(new CEqual(Param, TE->getType(), TE));
}
// It would be very strange if there was no type assert in the type
// class let-declaration but we rather not let the compiler crash if that happens.
if (SigLet->TypeAssert) {
// Note that we can't do SigLet->TypeAssert->TypeExpression->getType()
// because we need to re-generate the type within the local context of
// this let-declaration.
// TODO make CEqual accept multiple nodes
addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let));
}
}
if (Let->Body) {
@ -542,12 +558,18 @@ namespace bolt {
}
}
Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty));
if (!Let->isInstance()) {
Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty));
}
}
void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) {
if (Decl->isSignature()) {
return;
}
setContext(Decl->Ctx);
std::vector<Type*> ParamTypes;
@ -770,7 +792,9 @@ namespace bolt {
auto D = static_cast<TypeclassConstraintExpression*>(C);
std::vector<Type*> Types;
for (auto TE: D->TEs) {
Types.push_back(inferTypeExpression(TE));
auto TV = static_cast<TVarRigid*>(inferTypeExpression(TE));
TV->Provided.emplace(D->Name->getCanonicalText());
Types.push_back(TV);
}
return new CClass(D->Name->getCanonicalText(), Types);
}
@ -1175,136 +1199,6 @@ namespace bolt {
}
void Checker::checkTypeclassSigs(Node* N) {
struct LetVisitor : CSTVisitor<LetVisitor> {
Checker& C;
void visitLetDeclaration(FunctionDeclaration* Decl) {
// Only inspect those let-declarations that look like a function
if (Decl->Params.empty()) {
return;
}
// Will contain the type classes that were specified in the type assertion by the user.
// There might be some other signatures as well, but those are an implementation detail.
std::vector<TypeclassSignature> Expected;
// We must add the type class itself to Expected because in order for
// propagation to work the rigid type variables expect this class to be
// present even inside the current class. By adding it to Expected, we
// are effectively cancelling out the default behavior of requiring the
// presence of this type classes.
if (llvm::isa<ClassDeclaration>(Decl->Parent)) {
auto Class = llvm::cast<ClassDeclaration>(Decl->Parent);
std::vector<TVar *> Tys;
for (auto TE : Class->TypeVars) {
Tys.push_back(llvm::cast<TVar>(TE->getType()));
}
Expected.push_back(
TypeclassSignature{Class->Name->getCanonicalText(), Tys});
}
// Here we scan the type signature for type classes that user expects to be there.
if (Decl->TypeAssert != nullptr) {
if (llvm::isa<QualifiedTypeExpression>(Decl->TypeAssert->TypeExpression)) {
auto QTE = static_cast<QualifiedTypeExpression*>(Decl->TypeAssert->TypeExpression);
for (auto [C, Comma]: QTE->Constraints) {
if (llvm::isa<TypeclassConstraintExpression>(C)) {
auto TCE = static_cast<TypeclassConstraintExpression*>(C);
std::vector<TVar*> Tys;
for (auto TE: TCE->TEs) {
auto TV = TE->getType();
ZEN_ASSERT(llvm::isa<TVar>(TV));
Tys.push_back(static_cast<TVar*>(TV));
}
Expected.push_back(TypeclassSignature { TCE->Name->getCanonicalText(), Tys });
}
}
}
}
// Sort them lexically and remove any duplicates
std::sort(Expected.begin(), Expected.end());
Expected.erase(std::unique(Expected.begin(), Expected.end()), Expected.end());
// Will contain the type class signatures that our program inferred that
// at the very least should be present to make the body work.
std::vector<TypeclassSignature> Actual;
// This is ugly but it works. Scan all type variables local to this
// declaration and add the classes that they require to Actual.
for (auto Ty: *Decl->Ctx->TVs) {
auto S = Ty->solve();
if (llvm::isa<TVar>(S)) {
auto TV = static_cast<TVar*>(S);
for (auto Class: TV->Contexts) {
Actual.push_back(TypeclassSignature { Class, { TV } });
}
}
}
// Sort them lexically and remove any duplicates
std::sort(Actual.begin(), Actual.end());
Actual.erase(std::unique(Actual.begin(), Actual.end()), Actual.end());
auto ActualIter = Actual.begin();
auto ExpectedIter = Expected.begin();
for (; ActualIter != Actual.end() || ExpectedIter != Expected.end() ;) {
// Our program inferred no more type classes that should be present,
// yet Expected still did find a few that the user declared in a
// signature. No errors should be reported, and we can quit this loop.
if (ActualIter == Actual.end()) {
// TODO Maybe issue a warning that a type class went unused
break;
}
// There are no more type classes that were expected, so any remaining
// type classes in Actual will not have a corresponding signature.
// This should be reported as an error.
if (ExpectedIter == Expected.end()) {
for (; ActualIter != Actual.end(); ActualIter++) {
C.DE.add<TypeclassMissingDiagnostic>(*ActualIter, Decl);
}
break;
}
// If ExpectedIter is already at Show, but ActualIter is still at Eq,
// then we clearly missed the Eq in ExpectedIter. This clearly is an
// error, since the user missed something in a type signature.
if (*ActualIter < *ExpectedIter) {
C.DE.add<TypeclassMissingDiagnostic>(*ActualIter, Decl);
ActualIter++;
continue;
}
// If ActualIter is Show but ExpectedIter is still Eq, then the user
// specified too much type classes in a type signature. This is no error,
// but it might be worthwhile to issue a warning.
if (*ExpectedIter < *ActualIter) {
// DE.add<TypeclassMissingDiagnostic>(It2->Name, Decl);
ExpectedIter++;
continue;
}
// Both type class signatures are equal, cancelling each other out.
ActualIter++;
ExpectedIter++;
}
}
};
LetVisitor V { {}, *this };
V.visit(N);
}
Type* Checker::getType(TypedNode *Node) {
return Node->getType()->solve();
}
@ -1368,7 +1262,6 @@ namespace bolt {
ActiveContext = nullptr;
solve(new CMany(*SF->Ctx->Constraints));
checkTypeclassSigs(SF);
}
void Checker::solve(Constraint* Constraint) {
@ -1419,15 +1312,9 @@ namespace bolt {
if (Con1->Id != Con2-> Id) {
return false;
}
// TODO must handle a TApp
// ZEN_ASSERT(Con1->Args.size() == Con2->Args.size());
// for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) {
// if (!assignableTo(T1, T2)) {
// return false;
// }
// }
return true;
}
// TODO must handle a TApp
ZEN_UNREACHABLE
}
@ -1504,52 +1391,78 @@ namespace bolt {
return unify(Constraint->Left, Constraint->Right, false);
}
std::vector<TypeclassContext> findInstanceContext(TCon* Ty, TypeclassId& Class) {
std::vector<TypeclassContext> findInstanceContext(const TypeSig& Ty, TypeclassId& Class) {
auto Match = C.InstanceMap.find(Class);
std::vector<TypeclassContext> S;
if (Match != C.InstanceMap.end()) {
for (auto Instance: Match->second) {
if (assignableTo(Ty, Instance->TypeExps[0]->getType())) {
if (assignableTo(Ty.Orig, Instance->TypeExps[0]->getType())) {
std::vector<TypeclassContext> S;
// TODO handle TApp
// for (auto Arg: Ty->Args) {
// TypeclassContext Classes;
// // TODO
// S.push_back(Classes);
// }
for (auto Arg: Ty.Args) {
TypeclassContext Classes;
// TODO
S.push_back(Classes);
}
return S;
}
}
}
C.DE.add<InstanceNotFoundDiagnostic>(Class, Ty, getSource());
// TODO handle TApp
// for (auto Arg: Ty->Args) {
// S.push_back({});
// }
C.DE.add<InstanceNotFoundDiagnostic>(Class, Ty.Orig, getSource());
for (auto Arg: Ty.Args) {
S.push_back({});
}
return S;
}
TypeSig getTypeSig(Type* Ty) {
struct Visitor : TypeVisitor {
Type* Op = nullptr;
std::vector<Type*> Args;
void visitType(Type* Ty) override {
if (!Op) {
Op = Ty;
} else {
Args.push_back(Ty);
}
}
void visitAppType(TApp* Ty) override {
visitEachChild(Ty);
}
};
Visitor V;
V.visit(Ty);
return TypeSig { Ty, V.Op, V.Args };
}
void propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty) {
if (llvm::isa<TVar>(Ty)) {
auto TV = llvm::cast<TVar>(Ty);
for (auto Class: Classes) {
TV->Contexts.emplace(Class);
}
} else if (llvm::isa<TCon>(Ty)) {
if (TV->isRigid()) {
auto RV = static_cast<TVarRigid*>(Ty);
for (auto Id: RV->Contexts) {
if (!RV->Provided.count(Id)) {
C.DE.add<TypeclassMissingDiagnostic>(TypeclassSignature { Id, { RV } }, getSource());
}
}
}
} else if (llvm::isa<TCon>(Ty) || llvm::isa<TApp>(Ty)) {
auto Sig = getTypeSig(Ty);
for (auto Class: Classes) {
propagateClassTycon(Class, llvm::cast<TCon>(Ty));
propagateClassTycon(Class, Sig);
}
} else if (!Classes.empty()) {
C.DE.add<InvalidTypeToTypeclassDiagnostic>(Ty, std::vector(Classes.begin(), Classes.end()), getSource());
}
};
void propagateClassTycon(TypeclassId& Class, TCon* Ty) {
auto S = findInstanceContext(Ty, Class);
// TODO handle TApp
// for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) {
// propagateClasses(Classes, Arg);
// }
void propagateClassTycon(TypeclassId& Class, const TypeSig& Sig) {
auto S = findInstanceContext(Sig, Class);
for (auto [Classes, Arg]: zen::zip(S, Sig.Args)) {
propagateClasses(Classes, Arg);
}
};
/**

View file

@ -110,6 +110,8 @@ namespace bolt {
return "'return'";
case NodeKind::TypeKeyword:
return "'type'";
case NodeKind::ReferenceTypeExpression:
return "a type reference";
case NodeKind::FunctionDeclaration:
return "a function declaration";
case NodeKind::VariableDeclaration:

View file

@ -1029,7 +1029,6 @@ finish:
PubKeyword* Pub = nullptr;
FnKeyword* Fn;
MutKeyword* Mut = nullptr;
TypeAssert* TA = nullptr;
LetBody* Body = nullptr;
@ -1047,11 +1046,6 @@ finish:
return nullptr;
}
Fn = static_cast<FnKeyword*>(T0);
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::MutKeyword) {
Mut = static_cast<MutKeyword*>(T1);
Tokens.get();
}
auto Name = expectToken<Identifier>();
if (!Name) {
@ -1059,9 +1053,6 @@ finish:
Pub->unref();
}
Fn->unref();
if (Mut) {
Mut->unref();
}
skipToLineFoldEnd();
return nullptr;
}
@ -1146,6 +1137,7 @@ after_params:
checkLineFoldEnd();
finish:
return new FunctionDeclaration(
Pub,
Fn,