Split up Checker.hpp and make room for better type mismatch errors

This commit is contained in:
Sam Vervaeck 2023-05-22 22:37:58 +02:00
parent 508ef40bdf
commit 302823ac9b
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
7 changed files with 821 additions and 428 deletions

View file

@ -20,6 +20,7 @@ add_library(
src/Diagnostics.cc
src/Scanner.cc
src/Parser.cc
src/Types.cc
src/Checker.cc
src/IPRGraph.cc
)

View file

@ -6,170 +6,15 @@
#include "bolt/ByteString.hpp"
#include "bolt/Common.hpp"
#include "bolt/CST.hpp"
#include "bolt/Diagnostics.hpp"
#include "bolt/Type.hpp"
#include <istream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <optional>
namespace bolt {
class DiagnosticEngine;
class Node;
class Type;
class TVar;
using TVSub = std::unordered_map<TVar*, Type*>;
using TVSet = std::unordered_set<TVar*>;
using TypeclassContext = std::unordered_set<TypeclassId>;
enum class TypeKind : unsigned char {
Var,
Con,
Arrow,
Tuple,
TupleIndex,
};
class Type {
const TypeKind Kind;
protected:
inline Type(TypeKind Kind):
Kind(Kind) {}
public:
bool hasTypeVar(const TVar* TV);
void addTypeVars(TVSet& TVs);
inline TVSet getTypeVars() {
TVSet Out;
addTypeVars(Out);
return Out;
}
Type* substitute(const TVSub& Sub);
inline TypeKind getKind() const noexcept {
return Kind;
}
};
class TCon : public Type {
public:
const size_t Id;
std::vector<Type*> Args;
ByteString DisplayName;
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con;
}
};
enum class VarKind {
Rigid,
Unification,
};
class TVar : public Type {
public:
const size_t Id;
VarKind VK;
TypeclassContext Contexts;
inline TVar(size_t Id, VarKind VK):
Type(TypeKind::Var), Id(Id), VK(VK) {}
inline VarKind getVarKind() const noexcept {
return VK;
}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var;
}
};
class TVarRigid : public TVar {
public:
ByteString Name;
inline TVarRigid(size_t Id, ByteString Name):
TVar(Id, VarKind::Rigid), Name(Name) {}
};
class TArrow : public Type {
public:
std::vector<Type*> ParamTypes;
Type* ReturnType;
inline TArrow(
std::vector<Type*> ParamTypes,
Type* ReturnType
): Type(TypeKind::Arrow),
ParamTypes(ParamTypes),
ReturnType(ReturnType) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Arrow;
}
};
class TTuple : public Type {
public:
std::vector<Type*> ElementTypes;
inline TTuple(std::vector<Type*> ElementTypes):
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Tuple;
}
};
class TTupleIndex : public Type {
public:
Type* Ty;
std::size_t I;
inline TTupleIndex(Type* Ty, std::size_t I):
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::TupleIndex;
}
};
// template<typename T>
// struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept {
// return std::hash<decltype(*Value)>{}(*Value);
// }
// };
class Constraint;
@ -354,6 +199,8 @@ namespace bolt {
std::unordered_map<Node*, InferContext*> CallGraph;
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
Type* BoolType;
Type* IntType;
Type* StringType;
@ -412,20 +259,43 @@ namespace bolt {
Type* instantiate(Scheme* S, Node* Source);
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
std::vector<TypeclassContext> findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source);
void propagateClasses(TypeclassContext& Classes, Type* Ty, Node* Source);
void propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source);
void checkTypeclassSigs(Node* N);
std::vector<TypeclassContext> findInstanceContext(TCon* Ty, TypeclassId& Class);
void propagateClasses(TypeclassContext& Classes, Type* Ty);
void propagateClassTycon(TypeclassId& Class, TCon* Ty);
Type* simplify(Type* Ty);
void join(TVar* A, Type* B, Node* Source);
bool unify(Type* A, Type* B, Node* Source);
/**
* Assign a type to a unification variable.
*
* If there are class constraints, those are propagated.
*
* If this type variable is solved during inference, it will be removed from
* the inference context.
*
* Other side effects may occur.
*/
void join(TVar* A, Type* B);
Type* OrigLeft;
Type* OrigRight;
TypePath LeftPath;
TypePath RightPath;
Node* Source;
bool unify(Type* A, Type* B);
void unifyError();
void solveCEqual(CEqual* C);
void solve(Constraint* Constraint, TVSub& Solution);
/**
* Verifies that type class signatures on type asserts in let-declarations
* correctly declare the right type classes.
*/
void checkTypeclassSigs(Node* N);
public:
Checker(const LanguageConfig& Config, DiagnosticEngine& DE);

View file

@ -9,27 +9,10 @@
#include "bolt/ByteString.hpp"
#include "bolt/String.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
namespace bolt {
class Type;
class TCon;
class TVar;
class TTuple;
using TypeclassId = ByteString;
struct TypeclassSignature {
using TypeclassId = ByteString;
TypeclassId Id;
std::vector<TVar*> Params;
bool operator<(const TypeclassSignature& Other) const;
bool operator==(const TypeclassSignature& Other) const;
};
enum class DiagnosticKind : unsigned char {
UnexpectedToken,
UnexpectedString,
@ -95,13 +78,15 @@ namespace bolt {
class UnificationErrorDiagnostic : public Diagnostic {
public:
Type* Left;
Type* Right;
TypePath LeftPath;
TypePath RightPath;
Node* Source;
inline UnificationErrorDiagnostic(Type* Left, Type* Right, Node* Source):
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), Source(Source) {}
inline UnificationErrorDiagnostic(Type* Left, Type* Right, TypePath LeftPath, TypePath RightPath, Node* Source):
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {}
};

282
include/bolt/Type.hpp Normal file
View file

@ -0,0 +1,282 @@
#pragma once
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "bolt/ByteString.hpp"
namespace bolt {
class Type;
class TVar;
using TypeclassId = ByteString;
using TypeclassContext = std::unordered_set<TypeclassId>;
struct TypeclassSignature {
using TypeclassId = ByteString;
TypeclassId Id;
std::vector<TVar*> Params;
bool operator<(const TypeclassSignature& Other) const;
bool operator==(const TypeclassSignature& Other) const;
};
enum class TypeIndexKind {
ArrowParamType,
ArrowReturnType,
ConArg,
TupleElement,
End,
};
class TypeIndex {
protected:
friend class Type;
friend class TypeIterator;
TypeIndexKind Kind;
union {
std::size_t I;
};
TypeIndex(TypeIndexKind Kind):
Kind(Kind) {}
TypeIndex(TypeIndexKind Kind, std::size_t I):
Kind(Kind), I(I) {}
public:
bool operator==(const TypeIndex& Other) const noexcept;
void advance(const Type* Ty);
static TypeIndex forArrowReturnType() {
return { TypeIndexKind::ArrowReturnType };
}
static TypeIndex forArrowParamType(std::size_t I) {
return { TypeIndexKind::ArrowParamType, I };
}
static TypeIndex forConArg(std::size_t I) {
return { TypeIndexKind::ConArg, I };
}
static TypeIndex forTupleElement(std::size_t I) {
return { TypeIndexKind::TupleElement, I };
}
};
class TypeIterator {
friend class Type;
Type* Ty;
TypeIndex Index;
TypeIterator(Type* Ty, TypeIndex Index):
Ty(Ty), Index(Index) {}
public:
TypeIterator& operator++() noexcept {
Index.advance(Ty);
return *this;
}
bool operator==(const TypeIterator& Other) const noexcept {
return Ty == Other.Ty && Index == Other.Index;
}
Type* operator*() {
return Ty;
}
TypeIndex getIndex() const noexcept {
return Index;
}
};
using TypePath = std::vector<TypeIndex>;
using TVSub = std::unordered_map<TVar*, Type*>;
using TVSet = std::unordered_set<TVar*>;
enum class TypeKind : unsigned char {
Var,
Con,
Arrow,
Tuple,
TupleIndex,
};
class Type {
const TypeKind Kind;
protected:
inline Type(TypeKind Kind):
Kind(Kind) {}
public:
inline TypeKind getKind() const noexcept {
return Kind;
}
bool hasTypeVar(const TVar* TV);
void addTypeVars(TVSet& TVs);
inline TVSet getTypeVars() {
TVSet Out;
addTypeVars(Out);
return Out;
}
Type* substitute(const TVSub& Sub);
TypeIterator begin();
TypeIterator end();
TypeIndex getStartIndex();
TypeIndex getEndIndex();
Type* resolve(const TypeIndex& Index) const noexcept;
Type* resolve(const TypePath& Path) noexcept {
Type* Ty = this;
for (auto El: Path) {
Ty = Ty->resolve(El);
}
return Ty;
}
bool operator==(const Type& Other) const noexcept;
bool operator!=(const Type& Other) const noexcept {
return !(*this == Other);
}
};
class TCon : public Type {
public:
const size_t Id;
std::vector<Type*> Args;
ByteString DisplayName;
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con;
}
};
enum class VarKind {
Rigid,
Unification,
};
class TVar : public Type {
public:
const size_t Id;
VarKind VK;
TypeclassContext Contexts;
inline TVar(size_t Id, VarKind VK):
Type(TypeKind::Var), Id(Id), VK(VK) {}
inline VarKind getVarKind() const noexcept {
return VK;
}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var;
}
};
class TVarRigid : public TVar {
public:
ByteString Name;
inline TVarRigid(size_t Id, ByteString Name):
TVar(Id, VarKind::Rigid), Name(Name) {}
};
class TArrow : public Type {
public:
std::vector<Type*> ParamTypes;
Type* ReturnType;
inline TArrow(
std::vector<Type*> ParamTypes,
Type* ReturnType
): Type(TypeKind::Arrow),
ParamTypes(ParamTypes),
ReturnType(ReturnType) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Arrow;
}
};
class TTuple : public Type {
public:
std::vector<Type*> ElementTypes;
inline TTuple(std::vector<Type*> ElementTypes):
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Tuple;
}
};
class TTupleIndex : public Type {
public:
Type* Ty;
std::size_t I;
inline TTupleIndex(Type* Ty, std::size_t I):
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::TupleIndex;
}
};
// template<typename T>
// struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept {
// return std::hash<decltype(*Value)>{}(*Value);
// }
// };
}

View file

@ -25,165 +25,6 @@ namespace bolt {
std::string describe(const Type* Ty);
bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
if (Id < Other.Id) {
return true;
}
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Params[0]->Id < Other.Params[0]->Id;
}
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
}
void Type::addTypeVars(TVSet& TVs) {
switch (Kind) {
case TypeKind::Var:
TVs.emplace(static_cast<TVar*>(this));
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Arrow->ParamTypes) {
Ty->addTypeVars(TVs);
}
Arrow->ReturnType->addTypeVars(TVs);
break;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
Ty->addTypeVars(TVs);
}
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
Index->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
Ty->addTypeVars(TVs);
}
break;
}
}
}
bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) {
case TypeKind::Var:
return static_cast<TVar*>(this)->Id == TV->Id;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Arrow->ParamTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return Arrow->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
return Index->Ty->hasTypeVar(TV);
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
}
}
Type* Type::substitute(const TVSub &Sub) {
switch (Kind) {
case TypeKind::Var:
{
auto TV = static_cast<TVar*>(this);
auto Match = Sub.find(TV);
return Match != Sub.end() ? Match->second->substitute(Sub) : this;
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
bool Changed = false;
std::vector<Type*> NewParamTypes;
for (auto Ty: Arrow->ParamTypes) {
auto NewParamType = Ty->substitute(Sub);
if (NewParamType != Ty) {
Changed = true;
}
NewParamTypes.push_back(NewParamType);
}
auto NewRetTy = Arrow->ReturnType->substitute(Sub) ;
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
bool Changed = false;
std::vector<Type*> NewArgs;
for (auto Arg: Con->Args) {
auto NewArg = Arg->substitute(Sub);
if (NewArg != Arg) {
Changed = true;
}
NewArgs.push_back(NewArg);
}
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
}
case TypeKind::TupleIndex:
{
auto Tuple = static_cast<TTupleIndex*>(this);
auto NewTy = Tuple->Ty->substitute(Sub);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = Ty->substitute(Sub);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : this;
}
}
}
Constraint* Constraint::substitute(const TVSub &Sub) {
switch (Kind) {
case ConstraintKind::Class:
@ -1141,7 +982,7 @@ namespace bolt {
ZEN_UNREACHABLE
}
std::vector<TypeclassContext> Checker::findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source) {
std::vector<TypeclassContext> Checker::findInstanceContext(TCon* Ty, TypeclassId& Class) {
auto Match = InstanceMap.find(Class);
std::vector<TypeclassContext> S;
if (Match != InstanceMap.end()) {
@ -1164,7 +1005,7 @@ namespace bolt {
return S;
}
void Checker::propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty, Node* Source) {
void Checker::propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty) {
if (llvm::isa<TVar>(Ty)) {
auto TV = llvm::cast<TVar>(Ty);
for (auto Class: Classes) {
@ -1172,61 +1013,29 @@ namespace bolt {
}
} else if (llvm::isa<TCon>(Ty)) {
for (auto Class: Classes) {
propagateClassTycon(Class, llvm::cast<TCon>(Ty), Source);
propagateClassTycon(Class, llvm::cast<TCon>(Ty));
}
} else if (!Classes.empty()) {
DE.add<InvalidTypeToTypeclassDiagnostic>(Ty);
}
};
void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source) {
auto S = findInstanceContext(Ty, Class, Source);
void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty) {
auto S = findInstanceContext(Ty, Class);
for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) {
propagateClasses(Classes, Arg, Source);
propagateClasses(Classes, Arg);
}
};
class ArrowCursor {
std::stack<std::tuple<TArrow*, std::size_t>> Path;
public:
ArrowCursor(TArrow* Arr) {
Path.push({ Arr, 0 });
}
Type* next() {
while (!Path.empty()) {
auto& [Arr, I] = Path.top();
Type* Ty;
if (I == -1) {
Path.pop();
continue;
}
if (I == Arr->ParamTypes.size()) {
I = -1;
Ty = Arr->ReturnType;
} else {
Ty = Arr->ParamTypes[I];
I++;
}
if (llvm::isa<TArrow>(Ty)) {
Path.push({ static_cast<TArrow*>(Ty), 0 });
} else {
return Ty;
}
}
return nullptr;
}
};
void Checker::solveCEqual(CEqual* C) {
std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl;
if (!unify(C->Left, C->Right, C->Source)) {
DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source);
}
/* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */
OrigLeft = C->Left;
OrigRight = C->Right;
Source = C->Source;
unify(C->Left, C->Right);
LeftPath = {};
RightPath = {};
/* DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source); */
}
Type* Checker::simplify(Type* Ty) {
@ -1314,11 +1123,11 @@ namespace bolt {
return Ty;
}
void Checker::join(TVar* TV, Type* Ty, Node* Source) {
void Checker::join(TVar* TV, Type* Ty) {
Solution[TV] = Ty;
propagateClasses(TV->Contexts, Ty, Source);
propagateClasses(TV->Contexts, Ty);
// This is a very specific adjustment that is critical to the
// well-functioning of the infer/unify algorithm. When addConstraint() is
@ -1335,24 +1144,62 @@ namespace bolt {
}
bool Checker::unify(Type* A, Type* B, Node* Source) {
void Checker::unifyError() {
DE.add<UnificationErrorDiagnostic>(
simplify(OrigLeft),
simplify(OrigRight),
LeftPath,
RightPath,
Source
);
}
auto find = [&](auto OrigTy) {
auto Ty = OrigTy;
if (llvm::isa<TVar>(Ty)) {
auto TV = static_cast<TVar*>(Ty);
do {
auto Match = Solution.find(static_cast<TVar*>(Ty));
if (Match == Solution.end()) {
break;
}
Ty = Match->second;
} while (Ty->getKind() == TypeKind::Var);
// FIXME does this actually improove performance?
Solution[TV] = Ty;
class ArrowCursor {
std::stack<std::tuple<TArrow*, bool>> Stack;
TypePath& Path;
std::size_t I;
public:
ArrowCursor(TArrow* Arr, TypePath& Path):
Path(Path) {
Stack.push({ Arr, true });
Path.push_back(Arr->getStartIndex());
}
return Ty;
};
Type* next() {
while (!Stack.empty()) {
auto& [Arrow, First] = Stack.top();
auto& Index = Path.back();
if (!First) {
Index.advance(Arrow);
} else {
First = false;
}
Type* Ty;
if (Index == Arrow->getEndIndex()) {
Path.pop_back();
Stack.pop();
continue;
}
Ty = Arrow->resolve(Index);
if (llvm::isa<TArrow>(Ty)) {
auto NewIndex = Arrow->getStartIndex();
Stack.push({ static_cast<TArrow*>(Ty), true });
Path.push_back(NewIndex);
} else {
return Ty;
}
}
return nullptr;
}
};
bool Checker::unify(Type* A, Type* B) {
A = simplify(A);
B = simplify(B);
@ -1362,6 +1209,7 @@ namespace bolt {
auto Var2 = static_cast<TVar*>(B);
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) {
if (Var1->Id != Var2->Id) {
unifyError();
return false;
}
return true;
@ -1373,38 +1221,47 @@ namespace bolt {
From = Var2;
} else {
// Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification
// Either way, Var1 is a good candidate for being unified away
// Either way, Var1, being Unification, is a good candidate for being unified away
To = Var2;
From = Var1;
}
join(From, To, Source);
propagateClasses(From->Contexts, To, Source);
join(From, To);
return true;
}
if (llvm::isa<TVar>(A)) {
auto TV = static_cast<TVar*>(A);
// Rigid type variables can never unify with antything else than what we
// have already handled in the previous if-statement, so issue an error.
if (TV->getVarKind() == VarKind::Rigid) {
unifyError();
return false;
}
// Occurs check
if (B->hasTypeVar(TV)) {
// NOTE Just like GHC, we just display an error message indicating that
// A cannot match B, e.g. a cannot match [a]. It looks much better
// than obsure references to an occurs check
unifyError();
return false;
}
join(TV, B, Source);
join(TV, B);
return true;
}
if (llvm::isa<TVar>(B)) {
return unify(B, A, Source);
return unify(B, A);
}
if (llvm::isa<TArrow>(A) && llvm::isa<TArrow>(B)) {
auto C1 = ArrowCursor(static_cast<TArrow*>(A));
auto C2 = ArrowCursor(static_cast<TArrow*>(B));
auto C1 = ArrowCursor(static_cast<TArrow*>(A), LeftPath);
auto C2 = ArrowCursor(static_cast<TArrow*>(B), RightPath);
bool Success = true;
for (;;) {
auto T1 = C1.next();
auto T2 = C2.next();
@ -1412,13 +1269,15 @@ namespace bolt {
break;
}
if (T1 == nullptr || T2 == nullptr) {
return false;
unifyError();
Success = false;
break;
}
if (!unify(T1, T2, Source)) {
return false;
if (!unify(T1, T2)) {
Success = false;
}
}
return true;
return Success;
/* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */
/* return false; */
/* } */
@ -1434,26 +1293,31 @@ namespace bolt {
if (llvm::isa<TArrow>(A)) {
auto Arr = static_cast<TArrow*>(A);
if (Arr->ParamTypes.empty()) {
return unify(Arr->ReturnType, B, Source);
return unify(Arr->ReturnType, B);
}
}
if (llvm::isa<TArrow>(B)) {
return unify(B, A, Source);
return unify(B, A);
}
if (llvm::isa<TTuple>(A) && llvm::isa<TTuple>(B)) {
auto Tuple1 = static_cast<TTuple*>(A);
auto Tuple2 = static_cast<TTuple*>(B);
if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) {
unifyError();
return false;
}
auto Count = Tuple1->ElementTypes.size();
bool Success = true;
for (size_t I = 0; I < Count; I++) {
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], Source)) {
LeftPath.push_back(TypeIndex::forTupleElement(I));
RightPath.push_back(TypeIndex::forTupleElement(I));
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
}
return Success;
}
@ -1468,18 +1332,25 @@ namespace bolt {
auto Con1 = static_cast<TCon*>(A);
auto Con2 = static_cast<TCon*>(B);
if (Con1->Id != Con2->Id) {
unifyError();
return false;
}
ZEN_ASSERT(Con1->Args.size() == Con2->Args.size());
auto Count = Con1->Args.size();
bool Success = true;
for (std::size_t I = 0; I < Count; I++) {
if (!unify(Con1->Args[I], Con2->Args[I], Source)) {
return false;
LeftPath.push_back(TypeIndex::forConArg(I));
RightPath.push_back(TypeIndex::forConArg(I));
if (!unify(Con1->Args[I], Con2->Args[I])) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
}
return true;
return Success;
}
unifyError();
return false;
}

View file

@ -438,14 +438,29 @@ namespace bolt {
setBold(true);
Out << "error: ";
resetStyles();
Out << "the types " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET
<< " and " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << " failed to match\n\n";
auto Left = E.Left->resolve(E.LeftPath);
auto Right = E.Right->resolve(E.RightPath);
Out << "the types " << ANSI_FG_GREEN << describe(Left) << ANSI_RESET
<< " and " << ANSI_FG_GREEN << describe(Right) << ANSI_RESET << " failed to match\n\n";
if (E.Source) {
auto Range = E.Source->getRange();
//std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n";
writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red);
Out << "\n";
}
if (!E.LeftPath.empty()) {
setForegroundColor(Color::Yellow);
setBold(true);
Out << " info: ";
resetStyles();
Out << "type " << ANSI_FG_GREEN << describe(Left) << ANSI_RESET << " occurs in the full type " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET << "\n\n";
}
if (!E.RightPath.empty()) {
setForegroundColor(Color::Yellow);
setBold(true);
Out << " info: ";
resetStyles();
Out << "type " << ANSI_FG_GREEN << describe(Right) << ANSI_RESET << " occurs in the full type " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << "\n\n";
}
return;
}

369
src/Types.cc Normal file
View file

@ -0,0 +1,369 @@
#include "llvm/Support/Casting.h"
#include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/Type.hpp"
namespace bolt {
bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
if (Id < Other.Id) {
return true;
}
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Params[0]->Id < Other.Params[0]->Id;
}
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
}
bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
if (Kind != Other.Kind) {
return false;
}
switch (Kind) {
case TypeIndexKind::ConArg:
case TypeIndexKind::ArrowParamType:
case TypeIndexKind::TupleElement:
return I == Other.I;
default:
return true;
}
}
void TypeIndex::advance(const Type* Ty) {
switch (Kind) {
case TypeIndexKind::End:
break;
case TypeIndexKind::ArrowParamType:
{
auto Arrow = llvm::cast<TArrow>(Ty);
if (I+1 < Arrow->ParamTypes.size()) {
++I;
} else {
Kind = TypeIndexKind::ArrowReturnType;
}
break;
}
case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End;
break;
case TypeIndexKind::ConArg:
{
auto Con = llvm::cast<TCon>(Ty);
if (I+1 < Con->Args.size()) {
++I;
} else {
Kind = TypeIndexKind::End;
}
break;
}
case TypeIndexKind::TupleElement:
{
auto Tuple = llvm::cast<TTuple>(Ty);
if (I+1 < Tuple->ElementTypes.size()) {
++I;
} else {
Kind = TypeIndexKind::End;
}
break;
}
}
}
void Type::addTypeVars(TVSet& TVs) {
switch (Kind) {
case TypeKind::Var:
TVs.emplace(static_cast<TVar*>(this));
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Arrow->ParamTypes) {
Ty->addTypeVars(TVs);
}
Arrow->ReturnType->addTypeVars(TVs);
break;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
Ty->addTypeVars(TVs);
}
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
Index->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
Ty->addTypeVars(TVs);
}
break;
}
}
}
bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) {
case TypeKind::Var:
return static_cast<TVar*>(this)->Id == TV->Id;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Arrow->ParamTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return Arrow->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
return Index->Ty->hasTypeVar(TV);
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
}
}
Type* Type::substitute(const TVSub &Sub) {
switch (Kind) {
case TypeKind::Var:
{
auto TV = static_cast<TVar*>(this);
auto Match = Sub.find(TV);
return Match != Sub.end() ? Match->second->substitute(Sub) : this;
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
bool Changed = false;
std::vector<Type*> NewParamTypes;
for (auto Ty: Arrow->ParamTypes) {
auto NewParamType = Ty->substitute(Sub);
if (NewParamType != Ty) {
Changed = true;
}
NewParamTypes.push_back(NewParamType);
}
auto NewRetTy = Arrow->ReturnType->substitute(Sub) ;
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
bool Changed = false;
std::vector<Type*> NewArgs;
for (auto Arg: Con->Args) {
auto NewArg = Arg->substitute(Sub);
if (NewArg != Arg) {
Changed = true;
}
NewArgs.push_back(NewArg);
}
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
}
case TypeKind::TupleIndex:
{
auto Tuple = static_cast<TTupleIndex*>(this);
auto NewTy = Tuple->Ty->substitute(Sub);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = Ty->substitute(Sub);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : this;
}
}
}
Type* Type::resolve(const TypeIndex& Index) const noexcept {
switch (Index.Kind) {
case TypeIndexKind::ConArg:
return llvm::cast<TCon>(this)->Args[Index.I];
case TypeIndexKind::TupleElement:
return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
case TypeIndexKind::ArrowParamType:
return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
case TypeIndexKind::ArrowReturnType:
return llvm::cast<TArrow>(this)->ReturnType;
case TypeIndexKind::End:
ZEN_UNREACHABLE
}
ZEN_UNREACHABLE
}
bool Type::operator==(const Type& Other) const noexcept {
switch (Kind) {
case TypeKind::Var:
if (Other.Kind != TypeKind::Var) {
return false;
}
return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
case TypeKind::Tuple:
{
if (Other.Kind != TypeKind::Tuple) {
return false;
}
auto A = static_cast<const TTuple&>(*this);
auto B = static_cast<const TTuple&>(Other);
if (A.ElementTypes.size() != B.ElementTypes.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
case TypeKind::TupleIndex:
{
if (Other.Kind != TypeKind::TupleIndex) {
return false;
}
auto A = static_cast<const TTupleIndex&>(*this);
auto B = static_cast<const TTupleIndex&>(Other);
return A.I == B.I && *A.Ty == *B.Ty;
}
case TypeKind::Con:
{
if (Other.Kind != TypeKind::Con) {
return false;
}
auto A = static_cast<const TCon&>(*this);
auto B = static_cast<const TCon&>(Other);
if (A.Id != B.Id) {
return false;
}
if (A.Args.size() != B.Args.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.Args, B.Args)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
case TypeKind::Arrow:
{
// FIXME Do we really need to 'curry' this type?
if (Other.Kind != TypeKind::Arrow) {
return false;
}
auto A = static_cast<const TArrow&>(*this);
auto B = static_cast<const TArrow&>(Other);
/* ArrowCursor C1 { &A }; */
/* ArrowCursor C2 { &B }; */
/* for (;;) { */
/* auto T1 = C1.next(); */
/* auto T2 = C2.next(); */
/* if (T1 == nullptr && T2 == nullptr) { */
/* break; */
/* } */
/* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */
/* return false; */
/* } */
/* } */
if (A.ParamTypes.size() != B.ParamTypes.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) {
if (*T1 != *T2) {
return false;
}
}
return A.ReturnType != B.ReturnType;
}
}
}
TypeIterator Type::begin() {
return TypeIterator { this, getStartIndex() };
}
TypeIterator Type::end() {
return TypeIterator { this, getEndIndex() };
}
TypeIndex Type::getStartIndex() {
switch (Kind) {
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
if (Con->Args.empty()) {
return TypeIndex(TypeIndexKind::End);
}
return TypeIndex::forConArg(0);
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
if (Arrow->ParamTypes.empty()) {
return TypeIndex::forArrowReturnType();
}
return TypeIndex::forArrowParamType(0);
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
if (Tuple->ElementTypes.empty()) {
return TypeIndex(TypeIndexKind::End);
}
return TypeIndex::forTupleElement(0);
}
default:
return TypeIndex(TypeIndexKind::End);
}
}
TypeIndex Type::getEndIndex() {
return TypeIndex(TypeIndexKind::End);
}
}