Make Type a union and fix checking of tuple access

This commit is contained in:
Sam Vervaeck 2024-01-21 00:18:09 +01:00
parent 75124d097b
commit 4e11af005c
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
8 changed files with 1092 additions and 826 deletions

View file

@ -184,6 +184,12 @@ namespace bolt {
template<typename T> template<typename T>
NodeKind getNodeType(); NodeKind getNodeType();
enum NodeFlags {
NodeFlags_TypeIsSolved = 1,
};
using NodeFlagsMask = unsigned;
class Node { class Node {
unsigned RefCount = 1; unsigned RefCount = 1;
@ -192,6 +198,7 @@ namespace bolt {
public: public:
NodeFlagsMask Flags = 0;
Node* Parent = nullptr; Node* Parent = nullptr;
inline void ref() { inline void ref() {

View file

@ -178,6 +178,7 @@ namespace bolt {
Type* ListType; Type* ListType;
Type* IntType; Type* IntType;
Type* StringType; Type* StringType;
Type* UnitType;
Graph<Node*> RefGraph; Graph<Node*> RefGraph;
@ -217,9 +218,9 @@ namespace bolt {
/// Factory methods /// Factory methods
TCon* createConType(ByteString Name); Type* createConType(ByteString Name);
TVar* createTypeVar(); Type* createTypeVar();
TVarRigid* createRigidVar(ByteString Name); Type* createRigidVar(ByteString Name);
InferContext* createInferContext( InferContext* createInferContext(
InferContext* Parent = nullptr, InferContext* Parent = nullptr,
TVSet* TVs = new TVSet, TVSet* TVs = new TVSet,
@ -280,6 +281,11 @@ namespace bolt {
*/ */
Type* simplifyType(Type* Ty); Type* simplifyType(Type* Ty);
/**
* \internal
*/
Type* solveType(Type* Ty);
void check(SourceFile* SF); void check(SourceFile* SF);
inline Type* getBoolType() const { inline Type* getBoolType() const {

View file

@ -22,12 +22,17 @@ namespace bolt {
public: public:
bool FailOnError = false;
inline bool hasError() const noexcept { inline bool hasError() const noexcept {
return HasError; return HasError;
} }
template<typename D, typename ...Ts> template<typename D, typename ...Ts>
void add(Ts&&... Args) { void add(Ts&&... Args) {
// if (FailOnError) {
// ZEN_PANIC("An error diagnostic caused the program to abort.");
// }
HasError = true; HasError = true;
addDiagnostic(new D { std::forward<Ts>(Args)... }); addDiagnostic(new D { std::forward<Ts>(Args)... });
} }

View file

@ -2,7 +2,6 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include <memory>
#include "bolt/ByteString.hpp" #include "bolt/ByteString.hpp"
#include "bolt/String.hpp" #include "bolt/String.hpp"
@ -12,15 +11,16 @@
namespace bolt { namespace bolt {
enum class DiagnosticKind : unsigned char { enum class DiagnosticKind : unsigned char {
UnexpectedToken,
UnexpectedString,
BindingNotFound, BindingNotFound,
UnificationError,
TypeclassMissing,
InstanceNotFound,
TupleIndexOutOfRange,
InvalidTypeToTypeclass,
FieldNotFound, FieldNotFound,
InstanceNotFound,
InvalidTypeToTypeclass,
NotATuple,
TupleIndexOutOfRange,
TypeclassMissing,
UnexpectedString,
UnexpectedToken,
UnificationError,
}; };
class Diagnostic : std::runtime_error { class Diagnostic : std::runtime_error {
@ -168,10 +168,10 @@ namespace bolt {
class TupleIndexOutOfRangeDiagnostic : public Diagnostic { class TupleIndexOutOfRangeDiagnostic : public Diagnostic {
public: public:
TTuple* Tuple; Type* Tuple;
std::size_t I; std::size_t I;
inline TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I): inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I):
Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {} Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {}
unsigned getCode() const noexcept override { unsigned getCode() const noexcept override {
@ -217,4 +217,18 @@ namespace bolt {
}; };
class NotATupleDiagnostic : public Diagnostic {
public:
Type* Ty;
inline NotATupleDiagnostic(Type* Ty):
Diagnostic(DiagnosticKind::NotATuple), Ty(Ty) {}
unsigned getCode() const noexcept override {
return 2016;
}
};
} }

View file

@ -2,17 +2,21 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <type_traits> #include <optional>
#include <vector> #include <unistd.h>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <unordered_map>
#include <vector>
#include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/CST.hpp"
#include "bolt/ByteString.hpp" #include "bolt/ByteString.hpp"
namespace bolt { namespace bolt {
class Type; class Type;
class TVar;
class TCon; class TCon;
using TypeclassId = ByteString; using TypeclassId = ByteString;
@ -23,7 +27,7 @@ namespace bolt {
using TypeclassId = ByteString; using TypeclassId = ByteString;
TypeclassId Id; TypeclassId Id;
std::vector<TVar*> Params; std::vector<Type*> Params;
bool operator<(const TypeclassSignature& Other) const; bool operator<(const TypeclassSignature& Other) const;
bool operator==(const TypeclassSignature& Other) const; bool operator==(const TypeclassSignature& Other) const;
@ -144,8 +148,8 @@ namespace bolt {
using TypePath = std::vector<TypeIndex>; using TypePath = std::vector<TypeIndex>;
using TVSub = std::unordered_map<TVar*, Type*>; using TVSub = std::unordered_map<Type*, Type*>;
using TVSet = std::unordered_set<TVar*>; using TVSet = std::unordered_set<Type*>;
enum class TypeKind : unsigned char { enum class TypeKind : unsigned char {
Var, Var,
@ -160,48 +164,402 @@ namespace bolt {
Present, Present,
}; };
class Type { class Type;
const TypeKind Kind; struct TCon {
size_t Id;
ByteString DisplayName;
protected: bool operator==(const TCon& Other) const;
inline Type(TypeKind Kind): };
Kind(Kind) {}
public: struct TApp {
Type* Op;
Type* Arg;
inline TypeKind getKind() const noexcept { bool operator==(const TApp& Other) const;
};
enum class VarKind {
Rigid,
Unification,
};
struct TVar {
VarKind VK;
size_t Id;
TypeclassContext Context;
std::optional<ByteString> Name;
std::optional<TypeclassContext> Provided;
VarKind getKind() const {
return VK;
}
bool isUni() const {
return VK == VarKind::Unification;
}
bool isRigid() const {
return VK == VarKind::Rigid;
}
bool operator==(const TVar& Other) const;
};
struct TArrow {
Type* ParamType;
Type* ReturnType;
bool operator==(const TArrow& Other) const;
};
struct TTuple {
std::vector<Type*> ElementTypes;
bool operator==(const TTuple& Other) const;
};
struct TTupleIndex {
Type* Ty;
std::size_t I;
bool operator==(const TTupleIndex& Other) const;
};
struct TNil {
bool operator==(const TNil& Other) const;
};
struct TField {
ByteString Name;
Type* Ty;
Type* RestTy;
bool operator==(const TField& Other) const;
};
struct TAbsent {
bool operator==(const TAbsent& Other) const;
};
struct TPresent {
Type* Ty;
bool operator==(const TPresent& Other) const;
};
struct Type {
TypeKind Kind;
Type* Parent = this;
union {
TCon Con;
TApp App;
TVar Var;
TArrow Arrow;
TTuple Tuple;
TTupleIndex TupleIndex;
TNil Nil;
TField Field;
TAbsent Absent;
TPresent Present;
};
Type(TCon&& Con):
Kind(TypeKind::Con), Con(std::move(Con)) {};
Type(TApp&& App):
Kind(TypeKind::App), App(std::move(App)) {};
Type(TVar&& Var):
Kind(TypeKind::Var), Var(std::move(Var)) {};
Type(TArrow&& Arrow):
Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {};
Type(TTuple&& Tuple):
Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {};
Type(TTupleIndex&& TupleIndex):
Kind(TypeKind::TupleIndex), TupleIndex(std::move(TupleIndex)) {};
Type(TNil&& Nil):
Kind(TypeKind::Nil), Nil(std::move(Nil)) {};
Type(TField&& Field):
Kind(TypeKind::Field), Field(std::move(Field)) {};
Type(TAbsent&& Absent):
Kind(TypeKind::Absent), Absent(std::move(Absent)) {};
Type(TPresent&& Present):
Kind(TypeKind::Present), Present(std::move(Present)) {};
// Type(TCon Con): Kind(TypeKind::Con) {
// new (&Con)TCon(Con);
// }
// Type(TApp App): Kind(TypeKind::App) {
// new (&App)TApp(App);
// }
// Type(TVar Var): Kind(TypeKind::Var) {
// new (&Var)TVar(Var);
// }
// Type(TArrow Arrow): Kind(TypeKind::Arrow) {
// new (&Arrow)TArrow(Arrow);
// }
// Type(TTuple Tuple): Kind(TypeKind::Tuple) {
// new (&Tuple)TTuple(Tuple);
// }
// Type(TTupleIndex TupleIndex): Kind(TypeKind::TupleIndex) {
// new (&TupleIndex)TTupleIndex(TupleIndex);
// }
// Type(TNil Nil): Kind(TypeKind::Nil) {
// new (&Nil)TNil(Nil);
// }
// Type(TField Field): Kind(TypeKind::Field) {
// new (&Field)TField(Field);
// }
// Type(TAbsent Absent): Kind(TypeKind::Absent) {
// new (&Absent)TAbsent(Absent);
// }
// Type(TPresent Present): Kind(TypeKind::Present) {
// new (&Present)TPresent(Present);
// }
Type(const Type& Other): Kind(Other.Kind) {
switch (Kind) {
case TypeKind::Con:
new (&Con)TCon(Other.Con);
break;
case TypeKind::App:
new (&App)TApp(Other.App);
break;
case TypeKind::Var:
new (&Var)TVar(Other.Var);
break;
case TypeKind::Arrow:
new (&Arrow)TArrow(Other.Arrow);
break;
case TypeKind::Tuple:
new (&Tuple)TTuple(Other.Tuple);
break;
case TypeKind::TupleIndex:
new (&TupleIndex)TTupleIndex(Other.TupleIndex);
break;
case TypeKind::Nil:
new (&Nil)TNil(Other.Nil);
break;
case TypeKind::Field:
new (&Field)TField(Other.Field);
break;
case TypeKind::Absent:
new (&Absent)TAbsent(Other.Absent);
break;
case TypeKind::Present:
new (&Present)TPresent(Other.Present);
break;
}
}
Type(Type&& Other): Kind(std::move(Other.Kind)) {
switch (Kind) {
case TypeKind::Con:
new (&Con)TCon(std::move(Other.Con));
break;
case TypeKind::App:
new (&App)TApp(std::move(Other.App));
break;
case TypeKind::Var:
new (&Var)TVar(std::move(Other.Var));
break;
case TypeKind::Arrow:
new (&Arrow)TArrow(std::move(Other.Arrow));
break;
case TypeKind::Tuple:
new (&Tuple)TTuple(std::move(Other.Tuple));
break;
case TypeKind::TupleIndex:
new (&TupleIndex)TTupleIndex(std::move(Other.TupleIndex));
break;
case TypeKind::Nil:
new (&Nil)TNil(std::move(Other.Nil));
break;
case TypeKind::Field:
new (&Field)TField(std::move(Other.Field));
break;
case TypeKind::Absent:
new (&Absent)TAbsent(std::move(Other.Absent));
break;
case TypeKind::Present:
new (&Present)TPresent(std::move(Other.Present));
break;
}
}
TypeKind getKind() const {
return Kind; return Kind;
} }
bool hasTypeVar(const TVar* TV); bool isVarRigid() const {
return Kind == TypeKind::Var
void addTypeVars(TVSet& TVs); && asVar().getKind() == VarKind::Rigid;
inline TVSet getTypeVars() {
TVSet Out;
addTypeVars(Out);
return Out;
} }
/** bool isVar() const {
* Rewrites the entire substructure of a type to another one. return Kind == TypeKind::Var;
* }
* \param Recursive If true, a succesfull local rewritten type will be again
* rewriten until it encounters some terminals.
*/
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = false);
Type* substitute(const TVSub& Sub); TVar& asVar() {
ZEN_ASSERT(Kind == TypeKind::Var);
return Var;
}
Type* solve(); const TVar& asVar() const {
ZEN_ASSERT(Kind == TypeKind::Var);
return Var;
}
TypeIterator begin(); bool isApp() const {
TypeIterator end(); return Kind == TypeKind::App;
}
TypeIndex getStartIndex(); TApp& asApp() {
TypeIndex getEndIndex(); ZEN_ASSERT(Kind == TypeKind::App);
return App;
}
const TApp& asApp() const {
ZEN_ASSERT(Kind == TypeKind::App);
return App;
}
bool isCon() const {
return Kind == TypeKind::Con;
}
TCon& asCon() {
ZEN_ASSERT(Kind == TypeKind::Con);
return Con;
}
const TCon& asCon() const {
ZEN_ASSERT(Kind == TypeKind::Con);
return Con;
}
bool isArrow() const {
return Kind == TypeKind::Arrow;
}
TArrow& asArrow() {
ZEN_ASSERT(Kind == TypeKind::Arrow);
return Arrow;
}
const TArrow& asArrow() const {
ZEN_ASSERT(Kind == TypeKind::Arrow);
return Arrow;
}
bool isTuple() const {
return Kind == TypeKind::Tuple;
}
TTuple& asTuple() {
ZEN_ASSERT(Kind == TypeKind::Tuple);
return Tuple;
}
const TTuple& asTuple() const {
ZEN_ASSERT(Kind == TypeKind::Tuple);
return Tuple;
}
bool isTupleIndex() const {
return Kind == TypeKind::TupleIndex;
}
TTupleIndex& asTupleIndex() {
ZEN_ASSERT(Kind == TypeKind::TupleIndex);
return TupleIndex;
}
const TTupleIndex& asTupleIndex() const {
ZEN_ASSERT(Kind == TypeKind::TupleIndex);
return TupleIndex;
}
bool isField() const {
return Kind == TypeKind::Field;
}
TField& asField() {
ZEN_ASSERT(Kind == TypeKind::Field);
return Field;
}
const TField& asField() const {
ZEN_ASSERT(Kind == TypeKind::Field);
return Field;
}
bool isAbsent() const {
return Kind == TypeKind::Absent;
}
TAbsent& asAbsent() {
ZEN_ASSERT(Kind == TypeKind::Absent);
return Absent;
}
const TAbsent& asAbsent() const {
ZEN_ASSERT(Kind == TypeKind::Absent);
return Absent;
}
bool isPresent() const {
return Kind == TypeKind::Present;
}
TPresent& asPresent() {
ZEN_ASSERT(Kind == TypeKind::Present);
return Present;
}
const TPresent& asPresent() const {
ZEN_ASSERT(Kind == TypeKind::Present);
return Present;
}
bool isNil() const {
return Kind == TypeKind::Nil;
}
TNil& asNil() {
ZEN_ASSERT(Kind == TypeKind::Nil);
return Nil;
}
const TNil& asNil() const {
ZEN_ASSERT(Kind == TypeKind::Nil);
return Nil;
}
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = true);
Type* resolve(const TypeIndex& Index) const noexcept; Type* resolve(const TypeIndex& Index) const noexcept;
@ -213,207 +571,128 @@ namespace bolt {
return Ty; return Ty;
} }
bool operator==(const Type& Other) const noexcept; void set(Type* Ty) {
auto Root = find();
bool operator!=(const Type& Other) const noexcept { // It is not possible to set a solution twice.
return !(*this == Other); if (isVar()) {
ZEN_ASSERT(Root->isVar());
}
Root->Parent = Ty;
} }
}; Type* find() const {
Type* Curr = const_cast<Type*>(this);
class TCon : public Type { for (;;) {
public: auto Keep = Curr->Parent;
if (Keep == Curr) {
const size_t Id; return Keep;
ByteString DisplayName; }
Curr->Parent = Keep->Parent;
inline TCon(const size_t Id, ByteString DisplayName): Curr = Keep;
Type(TypeKind::Con), Id(Id), DisplayName(DisplayName) {} }
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con;
} }
}; bool operator==(const Type& Other) const;
class TApp : public Type { void destroy() {
public: switch (Kind) {
case TypeKind::Con:
Type* Op; App.~TApp();
Type* Arg; break;
case TypeKind::App:
inline TApp(Type* Op, Type* Arg): App.~TApp();
Type(TypeKind::App), Op(Op), Arg(Arg) {} break;
case TypeKind::Var:
static bool classof(const Type* Ty) { Var.~TVar();
return Ty->getKind() == TypeKind::App; break;
case TypeKind::Arrow:
Arrow.~TArrow();
break;
case TypeKind::Tuple:
Tuple.~TTuple();
break;
case TypeKind::TupleIndex:
TupleIndex.~TTupleIndex();
break;
case TypeKind::Nil:
Nil.~TNil();
break;
case TypeKind::Field:
Field.~TField();
break;
case TypeKind::Absent:
Absent.~TAbsent();
break;
case TypeKind::Present:
Present.~TPresent();
break;
}
} }
}; Type& operator=(Type& Other) {
destroy();
enum class VarKind { Kind = Other.Kind;
Rigid, switch (Kind) {
Unification, case TypeKind::Con:
}; App = Other.App;
break;
class TVar : public Type { case TypeKind::App:
App = Other.App;
Type* Parent = this; break;
case TypeKind::Var:
public: Var = Other.Var;
break;
const size_t Id; case TypeKind::Arrow:
VarKind VK; Arrow = Other.Arrow;
break;
TypeclassContext Contexts; case TypeKind::Tuple:
Tuple = Other.Tuple;
inline TVar(size_t Id, VarKind VK): break;
Type(TypeKind::Var), Id(Id), VK(VK) {} case TypeKind::TupleIndex:
TupleIndex = Other.TupleIndex;
inline VarKind getVarKind() const noexcept { break;
return VK; case TypeKind::Nil:
Nil = Other.Nil;
break;
case TypeKind::Field:
Field = Other.Field;
break;
case TypeKind::Absent:
Absent = Other.Absent;
break;
case TypeKind::Present:
Present = Other.Present;
break;
}
return *this;
} }
inline bool isRigid() const noexcept { bool hasTypeVar(Type* TV) const;
return VK == VarKind::Rigid;
TypeIterator begin();
TypeIterator end();
TypeIndex getStartIndex() const;
TypeIndex getEndIndex() const;
Type* substitute(const TVSub& Sub);
void visitEachChild(std::function<void(Type*)> Proc);
TVSet getTypeVars();
~Type() {
destroy();
} }
Type* find(); static Type* buildArrow(std::vector<Type*> ParamTypes, Type* ReturnType) {
void set(Type* Ty);
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var;
}
};
class TVarRigid : public TVar {
public:
ByteString Name;
TypeclassContext Provided;
inline TVarRigid(size_t Id, ByteString Name):
TVar(Id, VarKind::Rigid), Name(Name) {}
};
class TArrow : public Type {
public:
Type* ParamType;
Type* ReturnType;
inline TArrow(
Type* ParamType,
Type* ReturnType
): Type(TypeKind::Arrow),
ParamType(ParamType),
ReturnType(ReturnType) {}
static Type* build(std::vector<Type*> ParamTypes, Type* ReturnType) {
Type* Curr = ReturnType; Type* Curr = ReturnType;
for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) { for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) {
Curr = new TArrow(*Iter, Curr); Curr = new Type(TArrow(*Iter, Curr));
} }
return Curr; return Curr;
} }
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Arrow;
}
};
class TTuple : public Type {
public:
std::vector<Type*> ElementTypes;
inline TTuple(std::vector<Type*> ElementTypes):
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Tuple;
}
};
class TTupleIndex : public Type {
public:
Type* Ty;
std::size_t I;
inline TTupleIndex(Type* Ty, std::size_t I):
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::TupleIndex;
}
};
class TNil : public Type {
public:
inline TNil():
Type(TypeKind::Nil) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Nil;
}
};
class TField : public Type {
public:
ByteString Name;
Type* Ty;
Type* RestTy;
inline TField(
ByteString Name,
Type* Ty,
Type* RestTy
): Type(TypeKind::Field),
Name(Name),
Ty(Ty),
RestTy(RestTy) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Field;
}
};
class TAbsent : public Type {
public:
inline TAbsent():
Type(TypeKind::Absent) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Absent;
}
};
class TPresent : public Type {
public:
Type* Ty;
inline TPresent(Type* Ty):
Type(TypeKind::Present), Ty(Ty) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Present;
}
}; };
template<bool IsConst> template<bool IsConst>
@ -426,48 +705,49 @@ namespace bolt {
virtual void enterType(C<Type>* Ty) {} virtual void enterType(C<Type>* Ty) {}
virtual void exitType(C<Type>* Ty) {} virtual void exitType(C<Type>* Ty) {}
virtual void visitType(C<Type>* Ty) { // virtual void visitType(C<Type>* Ty) {
visitEachChild(Ty); // visitEachChild(Ty);
// }
virtual void visitVarType(C<TVar>& Ty) {
} }
virtual void visitVarType(C<TVar>* Ty) { virtual void visitAppType(C<TApp>& Ty) {
visitType(Ty); visit(Ty.Op);
visit(Ty.Arg);
} }
virtual void visitAppType(C<TApp>* Ty) { virtual void visitPresentType(C<TPresent>& Ty) {
visitType(Ty); visit(Ty.Ty);
} }
virtual void visitPresentType(C<TPresent>* Ty) { virtual void visitConType(C<TCon>& Ty) {
visitType(Ty);
} }
virtual void visitConType(C<TCon>* Ty) { virtual void visitArrowType(C<TArrow>& Ty) {
visitType(Ty); visit(Ty.ParamType);
visit(Ty.ReturnType);
} }
virtual void visitArrowType(C<TArrow>* Ty) { virtual void visitTupleType(C<TTuple>& Ty) {
visitType(Ty); for (auto ElTy: Ty.ElementTypes) {
visit(ElTy);
}
} }
virtual void visitTupleType(C<TTuple>* Ty) { virtual void visitTupleIndexType(C<TTupleIndex>& Ty) {
visitType(Ty); visit(Ty.Ty);
} }
virtual void visitTupleIndexType(C<TTupleIndex>* Ty) { virtual void visitAbsentType(C<TAbsent>& Ty) {
visitType(Ty);
} }
virtual void visitAbsentType(C<TAbsent>* Ty) { virtual void visitFieldType(C<TField>& Ty) {
visitType(Ty); visit(Ty.Ty);
visit(Ty.RestTy);
} }
virtual void visitFieldType(C<TField>* Ty) { virtual void visitNilType(C<TNil>& Ty) {
visitType(Ty);
}
virtual void visitNilType(C<TNil>* Ty) {
visitType(Ty);
} }
public: public:
@ -481,14 +761,14 @@ namespace bolt {
break; break;
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Arrow = static_cast<C<TArrow>*>(Ty); auto& Arrow = Ty->asArrow();
visit(Arrow->ParamType); visit(Arrow->ParamType);
visit(Arrow->ReturnType); visit(Arrow->ReturnType);
break; break;
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Tuple = static_cast<C<TTuple>*>(Ty); auto& Tuple = Ty->asTuple();
for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) { for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) {
visit(Tuple->ElementTypes[I]); visit(Tuple->ElementTypes[I]);
} }
@ -496,27 +776,27 @@ namespace bolt {
} }
case TypeKind::App: case TypeKind::App:
{ {
auto App = static_cast<C<TApp>*>(Ty); auto& App = Ty->asApp();
visit(App->Op); visit(App->Op);
visit(App->Arg); visit(App->Arg);
break; break;
} }
case TypeKind::Field: case TypeKind::Field:
{ {
auto Field = static_cast<C<TField>*>(Ty); auto& Field = Ty->asField();
visit(Field->Ty); visit(Field->Ty);
visit(Field->RestTy); visit(Field->RestTy);
break; break;
} }
case TypeKind::Present: case TypeKind::Present:
{ {
auto Present = static_cast<C<TPresent>*>(Ty); auto& Present = Ty->asPresent();
visit(Present->Ty); visit(Present->Ty);
break; break;
} }
case TypeKind::TupleIndex: case TypeKind::TupleIndex:
{ {
auto Index = static_cast<C<TTupleIndex>*>(Ty); auto& Index = Ty->asTupleIndex();
visit(Index->Ty); visit(Index->Ty);
break; break;
} }
@ -524,37 +804,41 @@ namespace bolt {
} }
void visit(C<Type>* Ty) { void visit(C<Type>* Ty) {
// Always look at the most solved solution
Ty = Ty->find();
enterType(Ty); enterType(Ty);
switch (Ty->getKind()) { switch (Ty->getKind()) {
case TypeKind::Present: case TypeKind::Present:
visitPresentType(static_cast<C<TPresent>*>(Ty)); visitPresentType(Ty->asPresent());
break; break;
case TypeKind::Absent: case TypeKind::Absent:
visitAbsentType(static_cast<C<TAbsent>*>(Ty)); visitAbsentType(Ty->asAbsent());
break; break;
case TypeKind::Nil: case TypeKind::Nil:
visitNilType(static_cast<C<TNil>*>(Ty)); visitNilType(Ty->asNil());
break; break;
case TypeKind::Field: case TypeKind::Field:
visitFieldType(static_cast<C<TField>*>(Ty)); visitFieldType(Ty->asField());
break; break;
case TypeKind::Con: case TypeKind::Con:
visitConType(static_cast<C<TCon>*>(Ty)); visitConType(Ty->asCon());
break; break;
case TypeKind::Arrow: case TypeKind::Arrow:
visitArrowType(static_cast<C<TArrow>*>(Ty)); visitArrowType(Ty->asArrow());
break; break;
case TypeKind::Var: case TypeKind::Var:
visitVarType(static_cast<C<TVar>*>(Ty)); visitVarType(Ty->asVar());
break; break;
case TypeKind::Tuple: case TypeKind::Tuple:
visitTupleType(static_cast<C<TTuple>*>(Ty)); visitTupleType(Ty->asTuple());
break; break;
case TypeKind::App: case TypeKind::App:
visitAppType(static_cast<C<TApp>*>(Ty)); visitAppType(Ty->asApp());
break; break;
case TypeKind::TupleIndex: case TypeKind::TupleIndex:
visitTupleIndexType(static_cast<C<TTupleIndex>*>(Ty)); visitTupleIndexType(Ty->asTupleIndex());
break; break;
} }
exitType(Ty); exitType(Ty);
@ -567,11 +851,4 @@ namespace bolt {
using TypeVisitor = TypeVisitorBase<false>; using TypeVisitor = TypeVisitorBase<false>;
using ConstTypeVisitor = TypeVisitorBase<true>; using ConstTypeVisitor = TypeVisitorBase<true>;
// template<typename T>
// struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept {
// return std::hash<decltype(*Value)>{}(*Value);
// }
// };
} }

View file

@ -1,13 +1,11 @@
#include <algorithm> #include <algorithm>
#include <iterator>
#include <stack> #include <stack>
#include <map> #include <map>
#include "bolt/Type.hpp"
#include "zen/config.hpp" #include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/Type.hpp"
#include "bolt/CSTVisitor.hpp" #include "bolt/CSTVisitor.hpp"
#include "bolt/DiagnosticEngine.hpp" #include "bolt/DiagnosticEngine.hpp"
#include "bolt/Diagnostics.hpp" #include "bolt/Diagnostics.hpp"
@ -39,29 +37,30 @@ namespace bolt {
Type* Checker::simplifyType(Type* Ty) { Type* Checker::simplifyType(Type* Ty) {
return Ty->rewrite([&](auto Ty) { Ty = Ty->find();
if (Ty->getKind() == TypeKind::Var) { if (Ty->isTupleIndex()) {
Ty = static_cast<TVar*>(Ty)->find(); auto& Index = Ty->asTupleIndex();
} auto MaybeTuple = simplifyType(Index.Ty);
if (MaybeTuple->isTuple()) {
if (Ty->getKind() == TypeKind::TupleIndex) { auto& Tuple = MaybeTuple->asTuple();
auto Index = static_cast<TTupleIndex*>(Ty); if (Index.I >= Tuple.ElementTypes.size()) {
auto MaybeTuple = simplifyType(Index->Ty); DE.add<TupleIndexOutOfRangeDiagnostic>(MaybeTuple, Index.I);
if (MaybeTuple->getKind() == TypeKind::Tuple) {
auto Tuple = static_cast<TTuple*>(MaybeTuple);
if (Index->I >= Tuple->ElementTypes.size()) {
DE.add<TupleIndexOutOfRangeDiagnostic>(Tuple, Index->I);
} else { } else {
Ty = simplifyType(Tuple->ElementTypes[Index->I]); auto ElementTy = simplifyType(Tuple.ElementTypes[Index.I]);
Ty->set(ElementTy);
Ty = ElementTy;
} }
} else if (!MaybeTuple->isVar()) {
DE.add<NotATupleDiagnostic>(MaybeTuple);
} }
} }
return Ty; return Ty;
}
}, /*Recursive=*/true); Type* Checker::solveType(Type* Ty) {
return Ty->rewrite([this](auto Ty) { return simplifyType(Ty); }, true);
} }
Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE):
@ -70,6 +69,7 @@ namespace bolt {
IntType = createConType("Int"); IntType = createConType("Int");
StringType = createConType("String"); StringType = createConType("String");
ListType = createConType("List"); ListType = createConType("List");
UnitType = new Type(TTuple({}));
} }
Scheme* Checker::lookup(ByteString Name) { Scheme* Checker::lookup(ByteString Name) {
@ -293,7 +293,7 @@ namespace bolt {
setContext(Decl->Ctx); setContext(Decl->Ctx);
std::vector<TVar*> Vars; std::vector<Type*> Vars;
for (auto TE: Decl->TVs) { for (auto TE: Decl->TVs) {
auto TV = createRigidVar(TE->Name->getCanonicalText()); auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV); Decl->Ctx->TVs->emplace(TV);
@ -312,13 +312,20 @@ namespace bolt {
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member); auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
auto RetTy = Ty; auto RetTy = Ty;
for (auto Var: Vars) { for (auto Var: Vars) {
RetTy = new TApp(RetTy, Var); RetTy = new Type(TApp(RetTy, Var));
} }
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto Element: TupleMember->Elements) { for (auto Element: TupleMember->Elements) {
ParamTypes.push_back(inferTypeExpression(Element)); ParamTypes.push_back(inferTypeExpression(Element));
} }
Decl->Ctx->Parent->add(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, TArrow::build(ParamTypes, RetTy))); Decl->Ctx->Parent->add(
TupleMember->Name->getCanonicalText(),
new Forall(
Decl->Ctx->TVs,
Decl->Ctx->Constraints,
Type::buildArrow(ParamTypes, RetTy)
)
);
break; break;
} }
case NodeKind::RecordVariantDeclarationMember: case NodeKind::RecordVariantDeclarationMember:
@ -342,7 +349,7 @@ namespace bolt {
setContext(Decl->Ctx); setContext(Decl->Ctx);
std::vector<TVar*> Vars; std::vector<Type*> Vars;
for (auto TE: Decl->Vars) { for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText()); auto TV = createRigidVar(TE->Name->getCanonicalText());
Vars.push_back(TV); Vars.push_back(TV);
@ -355,15 +362,28 @@ namespace bolt {
Decl->Ctx->Parent->add(Name, new Forall(Ty)); Decl->Ctx->Parent->add(Name, new Forall(Ty));
// Corresponds to the logic of one branch of a VariantDeclarationMember // Corresponds to the logic of one branch of a VariantDeclarationMember
Type* FieldsTy = new TNil(); Type* FieldsTy = new Type(TNil());
for (auto Field: Decl->Fields) { for (auto Field: Decl->Fields) {
FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy); FieldsTy = new Type(
TField(
Field->Name->getCanonicalText(),
new Type(TPresent(inferTypeExpression(Field->TypeExpression))),
FieldsTy
)
);
} }
Type* RetTy = Ty; Type* RetTy = Ty;
for (auto TV: Vars) { for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV); RetTy = new Type(TApp(RetTy, TV));
} }
Decl->Ctx->Parent->add(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(FieldsTy, RetTy))); Decl->Ctx->Parent->add(
Name,
new Forall(
Decl->Ctx->TVs,
Decl->Ctx->Constraints,
new Type(TArrow(FieldsTy, RetTy))
)
);
popContext(); popContext();
break; break;
@ -444,11 +464,11 @@ namespace bolt {
auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) {
auto Id = Class->Name->getCanonicalText(); auto Id = Class->Name->getCanonicalText();
auto Ctx = &getContext(); auto Ctx = &getContext();
std::vector<TVar*> Out; std::vector<Type*> Out;
for (auto TE: Class->TypeVars) { for (auto TE: Class->TypeVars) {
auto Name = TE->Name->getCanonicalText(); auto Name = TE->Name->getCanonicalText();
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
TV->Contexts.emplace(Id); TV->asVar().Context.emplace(Id);
Ctx->add(Name, new Forall(TV)); Ctx->add(Name, new Forall(TV));
Out.push_back(TV); Out.push_back(TV);
} }
@ -586,7 +606,7 @@ namespace bolt {
RetType = createTypeVar(); RetType = createTypeVar();
} }
makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl); makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl);
setContext(OldCtx); setContext(OldCtx);
} }
@ -648,8 +668,8 @@ namespace bolt {
if (RetStmt->Expression) { if (RetStmt->Expression) {
makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression); makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression);
} else { } else {
ReturnType = new TTuple({}); ReturnType = UnitType;
makeEqual(new TTuple({}), getReturnType(), N); makeEqual(UnitType, getReturnType(), N);
} }
break; break;
} }
@ -691,18 +711,18 @@ namespace bolt {
} }
TCon* Checker::createConType(ByteString Name) { Type* Checker::createConType(ByteString Name) {
return new TCon(NextConTypeId++, Name); return new Type(TCon(NextConTypeId++, Name));
} }
TVarRigid* Checker::createRigidVar(ByteString Name) { Type* Checker::createRigidVar(ByteString Name) {
auto TV = new TVarRigid(NextTypeVarId++, Name); auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}}));
getContext().TVs->emplace(TV); getContext().TVs->emplace(TV);
return TV; return TV;
} }
TVar* Checker::createTypeVar() { Type* Checker::createTypeVar() {
auto TV = new TVar(NextTypeVarId++, VarKind::Unification); auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {}));
getContext().TVs->emplace(TV); getContext().TVs->emplace(TV);
return TV; return TV;
} }
@ -727,7 +747,7 @@ namespace bolt {
for (auto TV: *F->TVs) { for (auto TV: *F->TVs) {
auto Fresh = createTypeVar(); auto Fresh = createTypeVar();
// std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl; // std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl;
Fresh->Contexts = TV->Contexts; Fresh->asVar().Context = TV->asVar().Context;
Sub[TV] = Fresh; Sub[TV] = Fresh;
} }
@ -736,8 +756,8 @@ namespace bolt {
// FIXME improve this // FIXME improve this
if (Constraint->getKind() == ConstraintKind::Equal) { if (Constraint->getKind() == ConstraintKind::Equal) {
auto Eq = static_cast<CEqual*>(Constraint); auto Eq = static_cast<CEqual*>(Constraint);
Eq->Left = simplifyType(Eq->Left); Eq->Left = solveType(Eq->Left);
Eq->Right = simplifyType(Eq->Right); Eq->Right = solveType(Eq->Right);
} }
auto NewConstraint = Constraint->substitute(Sub); auto NewConstraint = Constraint->substitute(Sub);
@ -752,11 +772,11 @@ namespace bolt {
addConstraint(NewConstraint); addConstraint(NewConstraint);
} }
// Note the call to simplify? This is because constraints may have already // This call to solve happens because constraints may have already
// been solved, with some unification variables being erased. To make // been solved, with some unification variables being erased. To make
// sure we instantiate unification variables that are still in use // sure we instantiate unification variables that are still in use
// we solve before substituting. // we solve before substituting.
return simplifyType(F->Type)->substitute(Sub); return solveType(F->Type)->substitute(Sub);
} }
} }
@ -771,10 +791,8 @@ namespace bolt {
std::vector<Type*> Types; std::vector<Type*> Types;
for (auto TE: D->TEs) { for (auto TE: D->TEs) {
auto Ty = inferTypeExpression(TE); auto Ty = inferTypeExpression(TE);
ZEN_ASSERT(Ty->getKind() == TypeKind::Var && static_cast<TVar*>(Ty)->isRigid()); Ty->asVar().Provided->emplace(D->Name->getCanonicalText());
auto TV = static_cast<TVarRigid*>(Ty); Types.push_back(Ty);
TV->Provided.emplace(D->Name->getCanonicalText());
Types.push_back(TV);
} }
break; break;
} }
@ -813,7 +831,7 @@ namespace bolt {
auto AppTE = static_cast<AppTypeExpression*>(N); auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op, IsPoly); Type* Ty = inferTypeExpression(AppTE->Op, IsPoly);
for (auto Arg: AppTE->Args) { for (auto Arg: AppTE->Args) {
Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly)); Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly)));
} }
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
@ -830,9 +848,9 @@ namespace bolt {
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
} }
ZEN_ASSERT(Ty->getKind() == TypeKind::Var); ZEN_ASSERT(Ty->isVar());
N->setType(Ty); N->setType(Ty);
return static_cast<TVar*>(Ty); return Ty;
} }
case NodeKind::TupleTypeExpression: case NodeKind::TupleTypeExpression:
@ -842,7 +860,7 @@ namespace bolt {
for (auto [TE, Comma]: TupleTE->Elements) { for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE, IsPoly)); ElementTypes.push_back(inferTypeExpression(TE, IsPoly));
} }
auto Ty = new TTuple(ElementTypes); auto Ty = new Type(TTuple(ElementTypes));
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
@ -863,7 +881,7 @@ namespace bolt {
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
} }
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
auto Ty = TArrow::build(ParamTypes, ReturnType); auto Ty = Type::buildArrow(ParamTypes, ReturnType);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
@ -886,14 +904,14 @@ namespace bolt {
} }
Type* sortRow(Type* Ty) { Type* sortRow(Type* Ty) {
std::map<ByteString, TField*> Fields; std::map<ByteString, Type*> Fields;
while (Ty->getKind() == TypeKind::Field) { while (Ty->isField()) {
auto Field = static_cast<TField*>(Ty); auto& Field = Ty->asField();
Fields.emplace(Field->Name, Field); Fields.emplace(Field.Name, Ty);
Ty = Field->RestTy; Ty = Field.RestTy;
} }
for (auto [Name, Field]: Fields) { for (auto [Name, Field]: Fields) {
Ty = new TField(Name, Field->Ty, Ty); Ty = new Type(TField(Name, Field->asField().Ty, Ty));
} }
return Ty; return Ty;
} }
@ -930,7 +948,7 @@ namespace bolt {
setContext(OldCtx); setContext(OldCtx);
} }
if (!Match->Value) { if (!Match->Value) {
Ty = new TArrow(ValTy, Ty); Ty = new Type(TArrow(ValTy, Ty));
} }
break; break;
} }
@ -938,9 +956,13 @@ namespace bolt {
case NodeKind::RecordExpression: case NodeKind::RecordExpression:
{ {
auto Record = static_cast<RecordExpression*>(X); auto Record = static_cast<RecordExpression*>(X);
Ty = new TNil(); Ty = new Type(TNil());
for (auto [Field, Comma]: Record->Fields) { for (auto [Field, Comma]: Record->Fields) {
Ty = new TField(Field->Name->getCanonicalText(), new TPresent(inferExpression(Field->getExpression())), Ty); Ty = new Type(TField(
Field->Name->getCanonicalText(),
new Type(TPresent(inferExpression(Field->getExpression()))),
Ty
));
} }
Ty = sortRow(Ty); Ty = sortRow(Ty);
break; break;
@ -998,7 +1020,7 @@ namespace bolt {
for (auto Arg: Call->Args) { for (auto Arg: Call->Args) {
ArgTypes.push_back(inferExpression(Arg)); ArgTypes.push_back(inferExpression(Arg));
} }
makeEqual(OpTy, TArrow::build(ArgTypes, Ty), X); makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X);
break; break;
} }
@ -1008,14 +1030,15 @@ namespace bolt {
auto Scm = lookup(Infix->Operator->getText()); auto Scm = lookup(Infix->Operator->getText());
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator); DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
return createTypeVar(); Ty = createTypeVar();
break;
} }
auto OpTy = instantiate(Scm, Infix->Operator); auto OpTy = instantiate(Scm, Infix->Operator);
Ty = createTypeVar(); Ty = createTypeVar();
std::vector<Type*> ArgTys; std::vector<Type*> ArgTys;
ArgTys.push_back(inferExpression(Infix->Left)); ArgTys.push_back(inferExpression(Infix->Left));
ArgTys.push_back(inferExpression(Infix->Right)); ArgTys.push_back(inferExpression(Infix->Right));
makeEqual(TArrow::build(ArgTys, Ty), OpTy, X); makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X);
break; break;
} }
@ -1026,7 +1049,7 @@ namespace bolt {
for (auto [E, Comma]: Tuple->Elements) { for (auto [E, Comma]: Tuple->Elements) {
Types.push_back(inferExpression(E)); Types.push_back(inferExpression(E));
} }
Ty = new TTuple(Types); Ty = new Type(TTuple(Types));
break; break;
} }
@ -1038,7 +1061,7 @@ namespace bolt {
case NodeKind::IntegerLiteral: case NodeKind::IntegerLiteral:
{ {
auto I = static_cast<IntegerLiteral*>(Member->Name); auto I = static_cast<IntegerLiteral*>(Member->Name);
Ty = new TTupleIndex(ExprTy, I->getInteger()); Ty = new Type(TTupleIndex(ExprTy, I->getInteger()));
break; break;
} }
case NodeKind::Identifier: case NodeKind::Identifier:
@ -1046,7 +1069,7 @@ namespace bolt {
auto K = static_cast<Identifier*>(Member->Name); auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar(); Ty = createTypeVar();
auto RestTy = createTypeVar(); auto RestTy = createTypeVar();
makeEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member); makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member);
break; break;
} }
default: default:
@ -1102,7 +1125,7 @@ namespace bolt {
} }
auto Ty = instantiate(Scm, P); auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar(); auto RetTy = createTypeVar();
makeEqual(Ty, TArrow::build(ParamTypes, RetTy), P); makeEqual(Ty, Type::buildArrow(ParamTypes, RetTy), P);
return RetTy; return RetTy;
} }
@ -1113,7 +1136,7 @@ namespace bolt {
for (auto [Element, Comma]: P->Elements) { for (auto [Element, Comma]: P->Elements) {
ElementTypes.push_back(inferPattern(Element)); ElementTypes.push_back(inferPattern(Element));
} }
return new TTuple(ElementTypes); return new Type(TTuple(ElementTypes));
} }
case NodeKind::ListPattern: case NodeKind::ListPattern:
@ -1123,7 +1146,7 @@ namespace bolt {
for (auto [Element, Separator]: P->Elements) { for (auto [Element, Separator]: P->Elements) {
makeEqual(ElementType, inferPattern(Element), P); makeEqual(ElementType, inferPattern(Element), P);
} }
return new TApp(ListType, ElementType); return new Type(TApp(ListType, ElementType));
} }
case NodeKind::NestedPattern: case NodeKind::NestedPattern:
@ -1204,7 +1227,14 @@ namespace bolt {
} }
Type* Checker::getType(TypedNode *Node) { Type* Checker::getType(TypedNode *Node) {
return Node->getType()->solve(); auto Ty = Node->getType();
if (Node->Flags & NodeFlags_TypeIsSolved) {
return Ty;
}
Ty = solveType(Ty);
Node->setType(Ty);
Node->Flags |= NodeFlags_TypeIsSolved;
return Ty;
} }
void Checker::check(SourceFile *SF) { void Checker::check(SourceFile *SF) {
@ -1217,11 +1247,11 @@ namespace bolt {
addBinding("True", new Forall(BoolType)); addBinding("True", new Forall(BoolType));
addBinding("False", new Forall(BoolType)); addBinding("False", new Forall(BoolType));
auto A = createTypeVar(); auto A = createTypeVar();
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, TArrow::build({ A, A }, BoolType))); addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)));
addBinding("+", new Forall(TArrow::build({ IntType, IntType }, IntType))); addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
addBinding("-", new Forall(TArrow::build({ IntType, IntType }, IntType))); addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
addBinding("*", new Forall(TArrow::build({ IntType, IntType }, IntType))); addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
addBinding("/", new Forall(TArrow::build({ IntType, IntType }, IntType))); addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
populate(SF); populate(SF);
forwardDeclare(SF); forwardDeclare(SF);
auto SCCs = RefGraph.strongconnect(); auto SCCs = RefGraph.strongconnect();
@ -1243,6 +1273,27 @@ namespace bolt {
ActiveContext = nullptr; ActiveContext = nullptr;
solve(new CMany(*SF->Ctx->Constraints)); solve(new CMany(*SF->Ctx->Constraints));
class Visitor : public CSTVisitor<Visitor> {
Checker& C;
public:
Visitor(Checker& C):
C(C) {}
void visitAnnotation(Annotation* A) {
}
void visitExpression(Expression* X) {
C.getType(X);
}
} V(*this);
V.visit(SF);
} }
void Checker::solve(Constraint* Constraint) { void Checker::solve(Constraint* Constraint) {
@ -1281,10 +1332,10 @@ namespace bolt {
} }
bool assignableTo(Type* A, Type* B) { bool assignableTo(Type* A, Type* B) {
if (isa<TCon>(A) && isa<TCon>(B)) { if (A->isCon() && B->isCon()) {
auto Con1 = cast<TCon>(A); auto& Con1 = A->asCon();
auto Con2 = cast<TCon>(B); auto& Con2 = B->asCon();
if (Con1->Id != Con2-> Id) { if (Con1.Id != Con2.Id) {
return false; return false;
} }
return true; return true;
@ -1295,13 +1346,15 @@ namespace bolt {
class ArrowCursor { class ArrowCursor {
std::stack<std::tuple<TArrow*, bool>> Stack; /// Types on this stack are guaranteed to be arrow types.
std::stack<std::tuple<Type*, bool>> Stack;
TypePath& Path; TypePath& Path;
std::size_t I; std::size_t I;
public: public:
ArrowCursor(TArrow* Arr, TypePath& Path): ArrowCursor(Type* Arr, TypePath& Path):
Path(Path) { Path(Path) {
Stack.push({ Arr, true }); Stack.push({ Arr, true });
Path.push_back(Arr->getStartIndex()); Path.push_back(Arr->getStartIndex());
@ -1323,9 +1376,9 @@ namespace bolt {
continue; continue;
} }
Ty = Arrow->resolve(Index); Ty = Arrow->resolve(Index);
if (isa<TArrow>(Ty)) { if (Ty->isArrow()) {
auto NewIndex = Arrow->getStartIndex(); auto NewIndex = Arrow->getStartIndex();
Stack.push({ static_cast<TArrow*>(Ty), true }); Stack.push({ Ty, true });
Path.push_back(NewIndex); Path.push_back(NewIndex);
} else { } else {
return Ty; return Ty;
@ -1390,40 +1443,36 @@ namespace bolt {
} }
TypeSig getTypeSig(Type* Ty) { TypeSig getTypeSig(Type* Ty) {
struct Visitor : TypeVisitor {
Type* Op = nullptr; Type* Op = nullptr;
std::vector<Type*> Args; std::vector<Type*> Args;
void visitType(Type* Ty) override { std::function<void(Type*)> Visit = [&](Type* Ty) {
if (!Op) { if (Ty->isApp()) {
Visit(Ty->asApp().Op);
Visit(Ty->asApp().Arg);
} else if (!Op) {
Op = Ty; Op = Ty;
} else { } else {
Args.push_back(Ty); Args.push_back(Ty);
} }
}
void visitAppType(TApp* Ty) override {
visitEachChild(Ty);
}
}; };
Visitor V; Visit(Ty);
V.visit(Ty); return TypeSig { Ty, Op, Args };
return TypeSig { Ty, V.Op, V.Args };
} }
void propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty) { void propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty) {
if (isa<TVar>(Ty)) { if (Ty->isVar()) {
auto TV = cast<TVar>(Ty); auto TV = Ty->asVar();
for (auto Class: Classes) { for (auto Class: Classes) {
TV->Contexts.emplace(Class); TV.Context.emplace(Class);
} }
if (TV->isRigid()) { if (TV.isRigid()) {
auto RV = static_cast<TVarRigid*>(Ty); for (auto Id: TV.Context) {
for (auto Id: RV->Contexts) { if (!TV.Provided->count(Id)) {
if (!RV->Provided.count(Id)) { C.DE.add<TypeclassMissingDiagnostic>(TypeclassSignature { Id, { Ty } }, getSource());
C.DE.add<TypeclassMissingDiagnostic>(TypeclassSignature { Id, { RV } }, getSource());
} }
} }
} }
} else if (isa<TCon>(Ty) || isa<TApp>(Ty)) { } else if (Ty->isCon() || Ty->isApp()) {
auto Sig = getTypeSig(Ty); auto Sig = getTypeSig(Ty);
for (auto Class: Classes) { for (auto Class: Classes) {
propagateClassTycon(Class, Sig); propagateClassTycon(Class, Sig);
@ -1450,13 +1499,13 @@ namespace bolt {
* *
* Other side effects may occur. * Other side effects may occur.
*/ */
void join(TVar* TV, Type* Ty) { void join(Type* TV, Type* Ty) {
// std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl;
TV->set(Ty); TV->set(Ty);
propagateClasses(TV->Contexts, Ty); propagateClasses(TV->asVar().Context, Ty);
// This is a very specific adjustment that is critical to the // This is a very specific adjustment that is critical to the
// well-functioning of the infer/unify algorithm. When addConstraint() is // well-functioning of the infer/unify algorithm. When addConstraint() is
@ -1480,21 +1529,21 @@ namespace bolt {
}; };
bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) {
if (isa<TAbsent>(A) && isa<TAbsent>(B)) { if (A->isAbsent() && B->isAbsent()) {
return true; return true;
} }
if (isa<TAbsent>(B)) { if (B->isAbsent()) {
std::swap(A, B); std::swap(A, B);
DidSwap = !DidSwap; DidSwap = !DidSwap;
} }
if (isa<TAbsent>(A)) { if (A->isAbsent()) {
auto Present = static_cast<TPresent*>(B); auto& Present = B->asPresent();
C.DE.add<FieldNotFoundDiagnostic>(CurrentFieldName, C.simplifyType(getLeft()), LeftPath, getSource()); C.DE.add<FieldNotFoundDiagnostic>(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource());
return false; return false;
} }
auto Present1 = static_cast<TPresent*>(A); auto& Present1 = A->asPresent();
auto Present2 = static_cast<TPresent*>(B); auto& Present2 = B->asPresent();
return unify(Present1->Ty, Present2->Ty, DidSwap); return unify(Present1.Ty, Present2.Ty, DidSwap);
}; };
bool Unifier::unify(Type* A, Type* B, bool DidSwap) { bool Unifier::unify(Type* A, Type* B, bool DidSwap) {
@ -1504,8 +1553,8 @@ namespace bolt {
auto unifyError = [&]() { auto unifyError = [&]() {
C.DE.add<UnificationErrorDiagnostic>( C.DE.add<UnificationErrorDiagnostic>(
C.simplifyType(Constraint->Left), Constraint->Left,
C.simplifyType(Constraint->Right), Constraint->Right,
LeftPath, LeftPath,
RightPath, RightPath,
Constraint->Source Constraint->Source
@ -1549,50 +1598,50 @@ namespace bolt {
DidSwap = !DidSwap; DidSwap = !DidSwap;
}; };
if (isa<TVar>(A) && isa<TVar>(B)) { if (A->isVar() && B->isVar()) {
auto Var1 = static_cast<TVar*>(A); auto& Var1 = A->asVar();
auto Var2 = static_cast<TVar*>(B); auto& Var2 = B->asVar();
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { if (Var1.isRigid() && Var2.isRigid()) {
if (Var1->Id != Var2->Id) { if (Var1.Id != Var2.Id) {
unifyError(); unifyError();
return false; return false;
} }
return true; return true;
} }
TVar* To; Type* To;
TVar* From; Type* From;
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { if (Var1.isRigid() && Var2.isUni()) {
To = Var1; To = A;
From = Var2; From = B;
} else { } else {
// Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification
// Either way, Var1, being Unification, is a good candidate for being unified away // Either way, Var1, being Unification, is a good candidate for being unified away
To = Var2; To = B;
From = Var1; From = A;
} }
if (From->Id != To->Id) { if (From->asVar().Id != To->asVar().Id) {
join(From, To); join(From, To);
} }
return true; return true;
} }
if (isa<TVar>(B)) { if (B->isVar()) {
swap(); swap();
} }
if (isa<TVar>(A)) { if (A->isVar()) {
auto TV = static_cast<TVar*>(A); auto& TV = A->asVar();
// Rigid type variables can never unify with antything else than what we // Rigid type variables can never unify with antything else than what we
// have already handled in the previous if-statement, so issue an error. // have already handled in the previous if-statement, so issue an error.
if (TV->getVarKind() == VarKind::Rigid) { if (TV.isRigid()) {
unifyError(); unifyError();
return false; return false;
} }
// Occurs check // Occurs check
if (B->hasTypeVar(TV)) { if (B->hasTypeVar(A)) {
// NOTE Just like GHC, we just display an error message indicating that // NOTE Just like GHC, we just display an error message indicating that
// A cannot match B, e.g. a cannot match [a]. It looks much better // A cannot match B, e.g. a cannot match [a]. It looks much better
// than obsure references to an occurs check // than obsure references to an occurs check
@ -1600,25 +1649,25 @@ namespace bolt {
return false; return false;
} }
join(TV, B); join(A, B);
return true; return true;
} }
if (isa<TArrow>(A) && isa<TArrow>(B)) { if (A->isArrow() && B->isArrow()) {
auto Arrow1 = static_cast<TArrow*>(A); auto& Arrow1 = A->asArrow();
auto Arrow2 = static_cast<TArrow*>(B); auto& Arrow2 = B->asArrow();
bool Success = true; bool Success = true;
LeftPath.push_back(TypeIndex::forArrowParamType()); LeftPath.push_back(TypeIndex::forArrowParamType());
RightPath.push_back(TypeIndex::forArrowParamType()); RightPath.push_back(TypeIndex::forArrowParamType());
if (!unify(Arrow1->ParamType, Arrow2->ParamType, DidSwap)) { if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
RightPath.pop_back(); RightPath.pop_back();
LeftPath.push_back(TypeIndex::forArrowReturnType()); LeftPath.push_back(TypeIndex::forArrowReturnType());
RightPath.push_back(TypeIndex::forArrowReturnType()); RightPath.push_back(TypeIndex::forArrowReturnType());
if (!unify(Arrow1->ReturnType, Arrow2->ReturnType, DidSwap)) { if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
@ -1626,20 +1675,20 @@ namespace bolt {
return Success; return Success;
} }
if (isa<TApp>(A) && isa<TApp>(B)) { if (A->isApp() && B->isApp()) {
auto App1 = static_cast<TApp*>(A); auto& App1 = A->asApp();
auto App2 = static_cast<TApp*>(B); auto& App2 = B->asApp();
bool Success = true; bool Success = true;
LeftPath.push_back(TypeIndex::forAppOpType()); LeftPath.push_back(TypeIndex::forAppOpType());
RightPath.push_back(TypeIndex::forAppOpType()); RightPath.push_back(TypeIndex::forAppOpType());
if (!unify(App1->Op, App2->Op, DidSwap)) { if (!unify(App1.Op, App2.Op, DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
RightPath.pop_back(); RightPath.pop_back();
LeftPath.push_back(TypeIndex::forAppArgType()); LeftPath.push_back(TypeIndex::forAppArgType());
RightPath.push_back(TypeIndex::forAppArgType()); RightPath.push_back(TypeIndex::forAppArgType());
if (!unify(App1->Arg, App2->Arg, DidSwap)) { if (!unify(App1.Arg, App2.Arg, DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
@ -1647,19 +1696,19 @@ namespace bolt {
return Success; return Success;
} }
if (isa<TTuple>(A) && isa<TTuple>(B)) { if (A->isTuple() && B->isTuple()) {
auto Tuple1 = static_cast<TTuple*>(A); auto& Tuple1 = A->asTuple();
auto Tuple2 = static_cast<TTuple*>(B); auto& Tuple2 = B->asTuple();
if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) {
unifyError(); unifyError();
return false; return false;
} }
auto Count = Tuple1->ElementTypes.size(); auto Count = Tuple1.ElementTypes.size();
bool Success = true; bool Success = true;
for (size_t I = 0; I < Count; I++) { for (size_t I = 0; I < Count; I++) {
LeftPath.push_back(TypeIndex::forTupleElement(I)); LeftPath.push_back(TypeIndex::forTupleElement(I));
RightPath.push_back(TypeIndex::forTupleElement(I)); RightPath.push_back(TypeIndex::forTupleElement(I));
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], DidSwap)) { if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
@ -1668,84 +1717,85 @@ namespace bolt {
return Success; return Success;
} }
if (isa<TTupleIndex>(A) || isa<TTupleIndex>(B)) { if (A->isTupleIndex() || B->isTupleIndex()) {
// Type(s) could not be simplified at the beginning of this function, // Type(s) could not be simplified at the beginning of this function,
// so we have to re-visit the constraint when there is more information. // so we have to re-visit the constraint when there is more information.
C.Queue.push_back(Constraint); C.Queue.push_back(Constraint);
return true; return true;
} }
// if (isa<TTupleIndex>(A) && isa<TTupleIndex>(B)) { // This does not work because it ignores the indices
// if (A->isTupleIndex() && B->isTupleIndex()) {
// auto Index1 = static_cast<TTupleIndex*>(A); // auto Index1 = static_cast<TTupleIndex*>(A);
// auto Index2 = static_cast<TTupleIndex*>(B); // auto Index2 = static_cast<TTupleIndex*>(B);
// return unify(Index1->Ty, Index2->Ty, Source); // return unify(Index1->Ty, Index2->Ty, Source);
// } // }
if (isa<TCon>(A) && isa<TCon>(B)) { if (A->isCon() && B->isCon()) {
auto Con1 = static_cast<TCon*>(A); auto& Con1 = A->asCon();
auto Con2 = static_cast<TCon*>(B); auto& Con2 = B->asCon();
if (Con1->Id != Con2->Id) { if (Con1.Id != Con2.Id) {
unifyError(); unifyError();
return false; return false;
} }
return true; return true;
} }
if (isa<TNil>(A) && isa<TNil>(B)) { if (A->isNil() && B->isNil()) {
return true; return true;
} }
if (isa<TField>(A) && isa<TField>(B)) { if (A->isField() && B->isField()) {
auto Field1 = static_cast<TField*>(A); auto& Field1 = A->asField();
auto Field2 = static_cast<TField*>(B); auto& Field2 = B->asField();
bool Success = true; bool Success = true;
if (Field1->Name == Field2->Name) { if (Field1.Name == Field2.Name) {
LeftPath.push_back(TypeIndex::forFieldType()); LeftPath.push_back(TypeIndex::forFieldType());
RightPath.push_back(TypeIndex::forFieldType()); RightPath.push_back(TypeIndex::forFieldType());
CurrentFieldName = Field1->Name; CurrentFieldName = Field1.Name;
if (!unifyField(Field1->Ty, Field2->Ty, DidSwap)) { if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
RightPath.pop_back(); RightPath.pop_back();
LeftPath.push_back(TypeIndex::forFieldRest()); LeftPath.push_back(TypeIndex::forFieldRest());
RightPath.push_back(TypeIndex::forFieldRest()); RightPath.push_back(TypeIndex::forFieldRest());
if (!unify(Field1->RestTy, Field2->RestTy, DidSwap)) { if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); LeftPath.pop_back();
RightPath.pop_back(); RightPath.pop_back();
return Success; return Success;
} }
auto NewRestTy = new TVar(C.NextTypeVarId++, VarKind::Unification); auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++));
pushLeft(TypeIndex::forFieldRest()); pushLeft(TypeIndex::forFieldRest());
if (!unify(Field1->RestTy, new TField(Field2->Name, Field2->Ty, NewRestTy), DidSwap)) { if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) {
Success = false; Success = false;
} }
popLeft(); popLeft();
pushRight(TypeIndex::forFieldRest()); pushRight(TypeIndex::forFieldRest());
if (!unify(new TField(Field1->Name, Field1->Ty, NewRestTy), Field2->RestTy, DidSwap)) { if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) {
Success = false; Success = false;
} }
popRight(); popRight();
return Success; return Success;
} }
if (isa<TNil>(A) && isa<TField>(B)) { if (A->isNil() && B->isField()) {
swap(); swap();
} }
if (isa<TField>(A) && isa<TNil>(B)) { if (A->isField() && B->isNil()) {
auto Field = static_cast<TField*>(A); auto& Field = A->asField();
bool Success = true; bool Success = true;
pushLeft(TypeIndex::forFieldType()); pushLeft(TypeIndex::forFieldType());
CurrentFieldName = Field->Name; CurrentFieldName = Field.Name;
if (!unifyField(Field->Ty, new TAbsent, DidSwap)) { if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) {
Success = false; Success = false;
} }
popLeft(); popLeft();
pushLeft(TypeIndex::forFieldRest()); pushLeft(TypeIndex::forFieldRest());
if (!unify(Field->RestTy, B, DidSwap)) { if (!unify(Field.RestTy, B, DidSwap)) {
Success = false; Success = false;
} }
popLeft(); popLeft();
@ -1762,6 +1812,5 @@ namespace bolt {
A.unify(); A.unify();
} }
} }

View file

@ -193,41 +193,42 @@ namespace bolt {
} }
std::string describe(const Type* Ty) { std::string describe(const Type* Ty) {
Ty = Ty->find();
switch (Ty->getKind()) { switch (Ty->getKind()) {
case TypeKind::Var: case TypeKind::Var:
{ {
auto TV = static_cast<const TVar*>(Ty); auto TV = Ty->asVar();
if (TV->getVarKind() == VarKind::Rigid) { if (TV.isRigid()) {
return static_cast<const TVarRigid*>(TV)->Name; return *TV.Name;
} }
return "a" + std::to_string(TV->Id); return "a" + std::to_string(TV.Id);
} }
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Y = static_cast<const TArrow*>(Ty); auto Y = Ty->asArrow();
std::ostringstream Out; std::ostringstream Out;
Out << describe(Y->ParamType) << " -> " << describe(Y->ReturnType); Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType);
return Out.str(); return Out.str();
} }
case TypeKind::Con: case TypeKind::Con:
{ {
auto Y = static_cast<const TCon*>(Ty); auto Y = Ty->asCon();
return Y->DisplayName; return Y.DisplayName;
} }
case TypeKind::App: case TypeKind::App:
{ {
auto Y = static_cast<const TApp*>(Ty); auto Y = Ty->asApp();
return describe(Y->Op) + " " + describe(Y->Arg); return describe(Y.Op) + " " + describe(Y.Arg);
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
std::ostringstream Out; std::ostringstream Out;
auto Y = static_cast<const TTuple*>(Ty); auto Y = Ty->asTuple();
Out << "("; Out << "(";
if (Y->ElementTypes.size()) { if (Y.ElementTypes.size()) {
auto Iter = Y->ElementTypes.begin(); auto Iter = Y.ElementTypes.begin();
Out << describe(*Iter++); Out << describe(*Iter++);
while (Iter != Y->ElementTypes.end()) { while (Iter != Y.ElementTypes.end()) {
Out << ", " << describe(*Iter++); Out << ", " << describe(*Iter++);
} }
} }
@ -236,8 +237,8 @@ namespace bolt {
} }
case TypeKind::TupleIndex: case TypeKind::TupleIndex:
{ {
auto Y = static_cast<const TTupleIndex*>(Ty); auto Y = Ty->asTupleIndex();
return describe(Y->Ty) + "." + std::to_string(Y->I); return describe(Y.Ty) + "." + std::to_string(Y.I);
} }
case TypeKind::Nil: case TypeKind::Nil:
return "{}"; return "{}";
@ -245,19 +246,19 @@ namespace bolt {
return "Abs"; return "Abs";
case TypeKind::Present: case TypeKind::Present:
{ {
auto Y = static_cast<const TPresent*>(Ty); auto Y = Ty->asPresent();
return describe(Y->Ty); return describe(Y.Ty);
} }
case TypeKind::Field: case TypeKind::Field:
{ {
auto Y = static_cast<const TField*>(Ty); auto Y = Ty->asField();
std::ostringstream out; std::ostringstream out;
out << "{ " << Y->Name << ": " << describe(Y->Ty); out << "{ " << Y.Name << ": " << describe(Y.Ty);
Ty = Y->RestTy; Ty = Y.RestTy;
while (Ty->getKind() == TypeKind::Field) { while (Ty->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty); auto Y = Ty->asField();
out << "; " + Y->Name + ": " + describe(Y->Ty); out << "; " + Y.Name + ": " + describe(Y.Ty);
Ty = Y->RestTy; Ty = Y.RestTy;
} }
if (Ty->getKind() != TypeKind::Nil) { if (Ty->getKind() != TypeKind::Nil) {
out << "; " + describe(Ty); out << "; " + describe(Ty);
@ -561,53 +562,52 @@ namespace bolt {
void exitType(const Type* Ty) override { void exitType(const Type* Ty) override {
if (shouldUnderline()) { if (shouldUnderline()) {
W.setUnderline(false); W.setUnderline(false); // FIXME Should set to old value
} }
} }
void visitAppType(const TApp *Ty) override { void visitAppType(const TApp& Ty) override {
auto Y = static_cast<const TApp*>(Ty);
Path.push_back(TypeIndex::forAppOpType()); Path.push_back(TypeIndex::forAppOpType());
visit(Y->Op); visit(Ty.Op);
Path.pop_back(); Path.pop_back();
W.write(" "); W.write(" ");
Path.push_back(TypeIndex::forAppArgType()); Path.push_back(TypeIndex::forAppArgType());
visit(Y->Arg); visit(Ty.Arg);
Path.pop_back(); Path.pop_back();
} }
void visitVarType(const TVar* Ty) override { void visitVarType(const TVar& Ty) override {
if (Ty->getVarKind() == VarKind::Rigid) { if (Ty.isRigid()) {
W.write(static_cast<const TVarRigid*>(Ty)->Name); W.write(*Ty.Name);
return; return;
} }
W.write("a"); W.write("a");
W.write(Ty->Id); W.write(Ty.Id);
} }
void visitConType(const TCon *Ty) override { void visitConType(const TCon& Ty) override {
W.write(Ty->DisplayName); W.write(Ty.DisplayName);
} }
void visitArrowType(const TArrow* Ty) override { void visitArrowType(const TArrow& Ty) override {
Path.push_back(TypeIndex::forArrowParamType()); Path.push_back(TypeIndex::forArrowParamType());
visit(Ty->ParamType); visit(Ty.ParamType);
Path.pop_back(); Path.pop_back();
W.write(" -> "); W.write(" -> ");
Path.push_back(TypeIndex::forArrowReturnType()); Path.push_back(TypeIndex::forArrowReturnType());
visit(Ty->ReturnType); visit(Ty.ReturnType);
Path.pop_back(); Path.pop_back();
} }
void visitTupleType(const TTuple *Ty) override { void visitTupleType(const TTuple& Ty) override {
W.write("("); W.write("(");
if (Ty->ElementTypes.size()) { if (Ty.ElementTypes.size()) {
auto Iter = Ty->ElementTypes.begin(); auto Iter = Ty.ElementTypes.begin();
Path.push_back(TypeIndex::forTupleElement(0)); Path.push_back(TypeIndex::forTupleElement(0));
visit(*Iter++); visit(*Iter++);
Path.pop_back(); Path.pop_back();
std::size_t I = 1; std::size_t I = 1;
while (Iter != Ty->ElementTypes.end()) { while (Iter != Ty.ElementTypes.end()) {
W.write(", "); W.write(", ");
Path.push_back(TypeIndex::forTupleElement(I++)); Path.push_back(TypeIndex::forTupleElement(I++));
visit(*Iter++); visit(*Iter++);
@ -617,47 +617,47 @@ namespace bolt {
W.write(")"); W.write(")");
} }
void visitTupleIndexType(const TTupleIndex *Ty) override { void visitTupleIndexType(const TTupleIndex& Ty) override {
Path.push_back(TypeIndex::forTupleIndexType()); Path.push_back(TypeIndex::forTupleIndexType());
visit(Ty->Ty); visit(Ty.Ty);
Path.pop_back(); Path.pop_back();
W.write("."); W.write(".");
W.write(Ty->I); W.write(Ty.I);
} }
void visitNilType(const TNil *Ty) override { void visitNilType(const TNil& Ty) override {
W.write("{}"); W.write("{}");
} }
void visitAbsentType(const TAbsent *Ty) override { void visitAbsentType(const TAbsent& Ty) override {
W.write("Abs"); W.write("Abs");
} }
void visitPresentType(const TPresent *Ty) override { void visitPresentType(const TPresent& Ty) override {
Path.push_back(TypeIndex::forPresentType()); Path.push_back(TypeIndex::forPresentType());
visit(Ty->Ty); visit(Ty.Ty);
Path.pop_back(); Path.pop_back();
} }
void visitFieldType(const TField* Ty) override { void visitFieldType(const TField& Ty) override {
W.write("{ "); W.write("{ ");
W.write(Ty->Name); W.write(Ty.Name);
W.write(": "); W.write(": ");
Path.push_back(TypeIndex::forFieldType()); Path.push_back(TypeIndex::forFieldType());
visit(Ty->Ty); visit(Ty.Ty);
Path.pop_back(); Path.pop_back();
auto Ty2 = Ty->RestTy; auto Ty2 = Ty.RestTy;
Path.push_back(TypeIndex::forFieldRest()); Path.push_back(TypeIndex::forFieldRest());
std::size_t I = 1; std::size_t I = 1;
while (Ty2->getKind() == TypeKind::Field) { while (Ty2->isField()) {
auto Y = static_cast<const TField*>(Ty2); auto Y = Ty2->asField();
W.write("; "); W.write("; ");
W.write(Y->Name); W.write(Y.Name);
W.write(": "); W.write(": ");
Path.push_back(TypeIndex::forFieldType()); Path.push_back(TypeIndex::forFieldType());
visit(Y->Ty); visit(Y.Ty);
Path.pop_back(); Path.pop_back();
Ty2 = Y->RestTy; Ty2 = Y.RestTy;
Path.push_back(TypeIndex::forFieldRest()); Path.push_back(TypeIndex::forFieldRest());
++I; ++I;
} }
@ -730,7 +730,7 @@ namespace bolt {
case DiagnosticKind::BindingNotFound: case DiagnosticKind::BindingNotFound:
{ {
auto E = static_cast<const BindingNotFoundDiagnostic&>(D); auto& E = static_cast<const BindingNotFoundDiagnostic&>(D);
writePrefix(E); writePrefix(E);
write("binding "); write("binding ");
writeBinding(E.Name); writeBinding(E.Name);
@ -746,7 +746,7 @@ namespace bolt {
case DiagnosticKind::UnexpectedToken: case DiagnosticKind::UnexpectedToken:
{ {
auto E = static_cast<const UnexpectedTokenDiagnostic&>(D); auto& E = static_cast<const UnexpectedTokenDiagnostic&>(D);
writePrefix(E); writePrefix(E);
writeLoc(E.File, E.Actual->getStartLoc()); writeLoc(E.File, E.Actual->getStartLoc());
write(" expected "); write(" expected ");
@ -780,7 +780,7 @@ namespace bolt {
case DiagnosticKind::UnexpectedString: case DiagnosticKind::UnexpectedString:
{ {
auto E = static_cast<const UnexpectedStringDiagnostic&>(D); auto& E = static_cast<const UnexpectedStringDiagnostic&>(D);
writePrefix(E); writePrefix(E);
writeLoc(E.File, E.Location); writeLoc(E.File, E.Location);
write(" unexpected '"); write(" unexpected '");
@ -806,7 +806,7 @@ namespace bolt {
case DiagnosticKind::UnificationError: case DiagnosticKind::UnificationError:
{ {
auto E = static_cast<const UnificationErrorDiagnostic&>(D); auto& E = static_cast<const UnificationErrorDiagnostic&>(D);
auto Left = E.OrigLeft->resolve(E.LeftPath); auto Left = E.OrigLeft->resolve(E.LeftPath);
auto Right = E.OrigRight->resolve(E.RightPath); auto Right = E.OrigRight->resolve(E.RightPath);
writePrefix(E); writePrefix(E);
@ -857,7 +857,7 @@ namespace bolt {
case DiagnosticKind::TypeclassMissing: case DiagnosticKind::TypeclassMissing:
{ {
auto E = static_cast<const TypeclassMissingDiagnostic&>(D); auto& E = static_cast<const TypeclassMissingDiagnostic&>(D);
writePrefix(E); writePrefix(E);
write("the type class "); write("the type class ");
writeTypeclassSignature(E.Sig); writeTypeclassSignature(E.Sig);
@ -869,7 +869,7 @@ namespace bolt {
case DiagnosticKind::InstanceNotFound: case DiagnosticKind::InstanceNotFound:
{ {
auto E = static_cast<const InstanceNotFoundDiagnostic&>(D); auto& E = static_cast<const InstanceNotFoundDiagnostic&>(D);
writePrefix(E); writePrefix(E);
write("a type class instance "); write("a type class instance ");
writeTypeclassName(E.TypeclassName); writeTypeclassName(E.TypeclassName);
@ -883,7 +883,7 @@ namespace bolt {
case DiagnosticKind::TupleIndexOutOfRange: case DiagnosticKind::TupleIndexOutOfRange:
{ {
auto E = static_cast<const TupleIndexOutOfRangeDiagnostic&>(D); auto& E = static_cast<const TupleIndexOutOfRangeDiagnostic&>(D);
writePrefix(E); writePrefix(E);
write("the index "); write("the index ");
writeType(E.I); writeType(E.I);
@ -894,7 +894,7 @@ namespace bolt {
case DiagnosticKind::InvalidTypeToTypeclass: case DiagnosticKind::InvalidTypeToTypeclass:
{ {
auto E = static_cast<const InvalidTypeToTypeclassDiagnostic&>(D); auto& E = static_cast<const InvalidTypeToTypeclassDiagnostic&>(D);
writePrefix(E); writePrefix(E);
write("the type "); write("the type ");
writeType(E.Actual); writeType(E.Actual);
@ -911,7 +911,7 @@ namespace bolt {
case DiagnosticKind::FieldNotFound: case DiagnosticKind::FieldNotFound:
{ {
auto E = static_cast<const FieldNotFoundDiagnostic&>(D); auto& E = static_cast<const FieldNotFoundDiagnostic&>(D);
writePrefix(E); writePrefix(E);
write("the field '"); write("the field '");
write(E.Name); write(E.Name);
@ -921,6 +921,16 @@ namespace bolt {
break; break;
} }
case DiagnosticKind::NotATuple:
{
auto& E = static_cast<const NotATupleDiagnostic&>(D);
writePrefix(E);
write("the type ");
writeType(E.Ty);
write(" is not a tuple.\n");
break;
}
} }
} }

View file

@ -1,9 +1,8 @@
#include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/Common.hpp"
#include "bolt/Type.hpp" #include "bolt/Type.hpp"
#include <cwchar>
#include <sys/wait.h>
#include <vector>
namespace bolt { namespace bolt {
@ -13,13 +12,13 @@ namespace bolt {
} }
ZEN_ASSERT(Params.size() == 1); ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1); ZEN_ASSERT(Other.Params.size() == 1);
return Params[0]->Id < Other.Params[0]->Id; return Params[0]->asCon().Id < Other.Params[0]->asCon().Id;
} }
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
ZEN_ASSERT(Params.size() == 1); ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1); ZEN_ASSERT(Other.Params.size() == 1);
return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id; return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id;
} }
bool TypeIndex::operator==(const TypeIndex& Other) const noexcept { bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
@ -35,34 +34,120 @@ namespace bolt {
} }
} }
void TypeIndex::advance(const Type* Ty) { bool TCon::operator==(const TCon& Other) const {
switch (Kind) { return Id == Other.Id;
case TypeIndexKind::End:
break;
case TypeIndexKind::AppOpType:
Kind = TypeIndexKind::AppArgType;
break;
case TypeIndexKind::ArrowParamType:
Kind = TypeIndexKind::ArrowReturnType;
break;
case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End;
break;
case TypeIndexKind::FieldType:
Kind = TypeIndexKind::FieldRestType;
break;
case TypeIndexKind::FieldRestType:
case TypeIndexKind::TupleIndexType:
case TypeIndexKind::PresentType:
case TypeIndexKind::AppArgType:
case TypeIndexKind::TupleElement:
{
auto Tuple = cast<TTuple>(Ty);
if (I+1 < Tuple->ElementTypes.size()) {
++I;
} else {
Kind = TypeIndexKind::End;
} }
bool TApp::operator==(const TApp& Other) const {
return *Op == *Other.Op && *Arg == *Other.Arg;
}
bool TVar::operator==(const TVar& Other) const {
return Id == Other.Id;
}
bool TArrow::operator==(const TArrow& Other) const {
return *ParamType == *Other.ParamType
&& *ReturnType == *Other.ReturnType;
}
bool TTuple::operator==(const TTuple& Other) const {
for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
bool TTupleIndex::operator==(const TTupleIndex& Other) const {
return *Ty == *Other.Ty && I == Other.I;
}
bool TNil::operator==(const TNil& Other) const {
return true;
}
bool TField::operator==(const TField& Other) const {
return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy;
}
bool TAbsent::operator==(const TAbsent& Other) const {
return true;
}
bool TPresent::operator==(const TPresent& Other) const {
return *Ty == *Other.Ty;
}
bool Type::operator==(const Type& Other) const {
if (Kind != Other.Kind) {
return false;
}
switch (Kind) {
case TypeKind::Var:
return Var == Other.Var;
case TypeKind::Con:
return Con == Other.Con;
case TypeKind::Present:
return Present == Other.Present;
case TypeKind::Absent:
return Absent == Other.Absent;
case TypeKind::Arrow:
return Arrow == Other.Arrow;
case TypeKind::Field:
return Field == Other.Field;
case TypeKind::Nil:
return Nil == Other.Nil;
case TypeKind::Tuple:
return Tuple == Other.Tuple;
case TypeKind::TupleIndex:
return TupleIndex == Other.TupleIndex;
case TypeKind::App:
return App == Other.App;
}
}
void Type::visitEachChild(std::function<void(Type*)> Proc) {
switch (Kind) {
case TypeKind::Var:
case TypeKind::Absent:
case TypeKind::Nil:
case TypeKind::Con:
break;
case TypeKind::Arrow:
{
Proc(Arrow.ParamType);
Proc(Arrow.ReturnType);
break;
}
case TypeKind::Tuple:
{
for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) {
Proc(Tuple.ElementTypes[I]);
}
break;
}
case TypeKind::App:
{
Proc(App.Op);
Proc(App.Arg);
break;
}
case TypeKind::Field:
{
Proc(Field.Ty);
Proc(Field.RestTy);
break;
}
case TypeKind::Present:
{
Proc(Present.Ty);
break;
}
case TypeKind::TupleIndex:
{
Proc(TupleIndex.Ty);
break; break;
} }
} }
@ -81,49 +166,49 @@ namespace bolt {
return Ty2; return Ty2;
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Arrow = static_cast<TArrow*>(Ty2); auto Arrow = Ty2->asArrow();
bool Changed = false; bool Changed = false;
Type* NewParamType = Arrow->ParamType->rewrite(Fn); Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive);
if (NewParamType != Arrow->ParamType) { if (NewParamType != Arrow.ParamType) {
Changed = true; Changed = true;
} }
auto NewRetTy = Arrow->ReturnType->rewrite(Fn); auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive);
if (NewRetTy != Arrow->ReturnType) { if (NewRetTy != Arrow.ReturnType) {
Changed = true; Changed = true;
} }
return Changed ? new TArrow(NewParamType, NewRetTy) : Ty2; return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2;
} }
case TypeKind::Con: case TypeKind::Con:
return Ty2; return Ty2;
case TypeKind::App: case TypeKind::App:
{ {
auto App = static_cast<TApp*>(Ty2); auto App = Ty2->asApp();
auto NewOp = App->Op->rewrite(Fn); auto NewOp = App.Op->rewrite(Fn, Recursive);
auto NewArg = App->Arg->rewrite(Fn); auto NewArg = App.Arg->rewrite(Fn, Recursive);
if (NewOp == App->Op && NewArg == App->Arg) { if (NewOp == App.Op && NewArg == App.Arg) {
return App; return Ty2;
} }
return new TApp(NewOp, NewArg); return new Type(TApp(NewOp, NewArg));
} }
case TypeKind::TupleIndex: case TypeKind::TupleIndex:
{ {
auto Tuple = static_cast<TTupleIndex*>(Ty2); auto Tuple = Ty2->asTupleIndex();
auto NewTy = Tuple->Ty->rewrite(Fn); auto NewTy = Tuple.Ty->rewrite(Fn, Recursive);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple; return NewTy != Tuple.Ty ? new Type(TTupleIndex(NewTy, Tuple.I)) : Ty2;
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Tuple = static_cast<TTuple*>(Ty2); auto Tuple = Ty2->asTuple();
bool Changed = false; bool Changed = false;
std::vector<Type*> NewElementTypes; std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) { for (auto Ty: Tuple.ElementTypes) {
auto NewElementType = Ty->rewrite(Fn); auto NewElementType = Ty->rewrite(Fn, Recursive);
if (NewElementType != Ty) { if (NewElementType != Ty) {
Changed = true; Changed = true;
} }
NewElementTypes.push_back(NewElementType); NewElementTypes.push_back(NewElementType);
} }
return Changed ? new TTuple(NewElementTypes) : Ty2; return Changed ? new Type(TTuple(NewElementTypes)) : Ty2;
} }
case TypeKind::Nil: case TypeKind::Nil:
return Ty2; return Ty2;
@ -131,272 +216,77 @@ namespace bolt {
return Ty2; return Ty2;
case TypeKind::Field: case TypeKind::Field:
{ {
auto Field = static_cast<TField*>(Ty2); auto Field = Ty2->asField();
bool Changed = false; bool Changed = false;
auto NewTy = Field->Ty->rewrite(Fn); auto NewTy = Field.Ty->rewrite(Fn, Recursive);
if (NewTy != Field->Ty) { if (NewTy != Field.Ty) {
Changed = true; Changed = true;
} }
auto NewRestTy = Field->RestTy->rewrite(Fn); auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive);
if (NewRestTy != Field->RestTy) { if (NewRestTy != Field.RestTy) {
Changed = true; Changed = true;
} }
return Changed ? new TField(Field->Name, NewTy, NewRestTy) : Ty2; return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2;
} }
case TypeKind::Present: case TypeKind::Present:
{ {
auto Present = static_cast<TPresent*>(Ty2); auto Present = Ty2->asPresent();
auto NewTy = Present->Ty->rewrite(Fn); auto NewTy = Present.Ty->rewrite(Fn, Recursive);
if (NewTy == Present->Ty) { if (NewTy == Present.Ty) {
return Ty2; return Ty2;
} }
return new TPresent(NewTy); return new Type(TPresent(NewTy));
} }
} }
}
void Type::addTypeVars(TVSet& TVs) {
switch (Kind) {
case TypeKind::Var:
TVs.emplace(static_cast<TVar*>(this));
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
Arrow->ParamType->addTypeVars(TVs);
Arrow->ReturnType->addTypeVars(TVs);
break;
}
case TypeKind::Con:
break;
case TypeKind::App:
{
auto App = static_cast<TApp*>(this);
App->Op->addTypeVars(TVs);
App->Arg->addTypeVars(TVs);
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
Index->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
Ty->addTypeVars(TVs);
}
break;
}
case TypeKind::Nil:
break;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
Field->Ty->addTypeVars(TVs);
Field->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
Present->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Absent:
break;
}
}
bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) {
case TypeKind::Var:
return static_cast<TVar*>(this)->Id == TV->Id;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
return Arrow->ParamType->hasTypeVar(TV) || Arrow->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
return false;
case TypeKind::App:
{
auto App = static_cast<TApp*>(this);
return App->Op->hasTypeVar(TV) || App->Arg->hasTypeVar(TV);
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
return Index->Ty->hasTypeVar(TV);
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::Nil:
return false;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
return Field->Ty->hasTypeVar(TV) || Field->RestTy->hasTypeVar(TV);
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
return Present->Ty->hasTypeVar(TV);
}
case TypeKind::Absent:
return false;
}
}
Type* Type::solve() {
return rewrite([](auto Ty) {
if (Ty->getKind() == TypeKind::Var) {
return static_cast<TVar*>(Ty)->find();
}
return Ty;
});
} }
Type* Type::substitute(const TVSub &Sub) { Type* Type::substitute(const TVSub &Sub) {
return rewrite([&](auto Ty) { return rewrite([&](auto Ty) {
if (isa<TVar>(Ty)) { if (Ty->isVar()) {
auto TV = static_cast<TVar*>(Ty); auto Match = Sub.find(Ty);
auto Match = Sub.find(TV);
return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; return Match != Sub.end() ? Match->second->substitute(Sub) : Ty;
} }
return Ty; return Ty;
}); }, false);
} }
Type* Type::resolve(const TypeIndex& Index) const noexcept { Type* Type::resolve(const TypeIndex& Index) const noexcept {
switch (Index.Kind) { switch (Index.Kind) {
case TypeIndexKind::PresentType: case TypeIndexKind::PresentType:
return cast<TPresent>(this)->Ty; return this->asPresent().Ty;
case TypeIndexKind::AppOpType: case TypeIndexKind::AppOpType:
return cast<TApp>(this)->Op; return this->asApp().Op;
case TypeIndexKind::AppArgType: case TypeIndexKind::AppArgType:
return cast<TApp>(this)->Arg; return this->asApp().Arg;
case TypeIndexKind::TupleIndexType: case TypeIndexKind::TupleIndexType:
return cast<TTupleIndex>(this)->Ty; return this->asTupleIndex().Ty;
case TypeIndexKind::TupleElement: case TypeIndexKind::TupleElement:
return cast<TTuple>(this)->ElementTypes[Index.I]; return this->asTuple().ElementTypes[Index.I];
case TypeIndexKind::ArrowParamType: case TypeIndexKind::ArrowParamType:
return cast<TArrow>(this)->ParamType; return this->asArrow().ParamType;
case TypeIndexKind::ArrowReturnType: case TypeIndexKind::ArrowReturnType:
return cast<TArrow>(this)->ReturnType; return this->asArrow().ReturnType;
case TypeIndexKind::FieldType: case TypeIndexKind::FieldType:
return cast<TField>(this)->Ty; return this->asField().Ty;
case TypeIndexKind::FieldRestType: case TypeIndexKind::FieldRestType:
return cast<TField>(this)->RestTy; return this->asField().RestTy;
case TypeIndexKind::End: case TypeIndexKind::End:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
bool Type::operator==(const Type& Other) const noexcept { TVSet Type::getTypeVars() {
switch (Kind) { TVSet Out;
case TypeKind::Var: std::function<void(Type*)> visit = [&](Type* Ty) {
if (Other.Kind != TypeKind::Var) { if (Ty->isVar()) {
return false; Out.emplace(Ty);
} return;
return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
case TypeKind::Tuple:
{
if (Other.Kind != TypeKind::Tuple) {
return false;
}
auto A = static_cast<const TTuple&>(*this);
auto B = static_cast<const TTuple&>(Other);
if (A.ElementTypes.size() != B.ElementTypes.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
case TypeKind::TupleIndex:
{
if (Other.Kind != TypeKind::TupleIndex) {
return false;
}
auto A = static_cast<const TTupleIndex&>(*this);
auto B = static_cast<const TTupleIndex&>(Other);
return A.I == B.I && *A.Ty == *B.Ty;
}
case TypeKind::Con:
{
if (Other.Kind != TypeKind::Con) {
return false;
}
auto A = static_cast<const TCon&>(*this);
auto B = static_cast<const TCon&>(Other);
if (A.Id != B.Id) {
return false;
}
return true;
}
case TypeKind::App:
{
if (Other.Kind != TypeKind::App) {
return false;
}
auto A = static_cast<const TApp&>(*this);
auto B = static_cast<const TApp&>(Other);
return *A.Op == *B.Op && *A.Arg == *B.Arg;
}
case TypeKind::Arrow:
{
if (Other.Kind != TypeKind::Arrow) {
return false;
}
auto A = static_cast<const TArrow&>(*this);
auto B = static_cast<const TArrow&>(Other);
return *A.ParamType == *B.ParamType && *A.ReturnType == *B.ReturnType;
}
case TypeKind::Absent:
if (Other.Kind != TypeKind::Absent) {
return false;
}
return true;
case TypeKind::Nil:
if (Other.Kind != TypeKind::Nil) {
return false;
}
return true;
case TypeKind::Present:
{
if (Other.Kind != TypeKind::Present) {
return false;
}
auto A = static_cast<const TPresent&>(*this);
auto B = static_cast<const TPresent&>(Other);
return *A.Ty == *B.Ty;
}
case TypeKind::Field:
{
if (Other.Kind != TypeKind::Field) {
return false;
}
auto A = static_cast<const TField&>(*this);
auto B = static_cast<const TField&>(Other);
return A.Name == B.Name && *A.Ty == *B.Ty && *A.RestTy == *B.RestTy;
}
} }
Ty->visitEachChild(visit);
};
visit(this);
return Out;
} }
TypeIterator Type::begin() { TypeIterator Type::begin() {
@ -407,14 +297,13 @@ namespace bolt {
return TypeIterator { this, getEndIndex() }; return TypeIterator { this, getEndIndex() };
} }
TypeIndex Type::getStartIndex() { TypeIndex Type::getStartIndex() const {
switch (Kind) { switch (Kind) {
case TypeKind::Arrow: case TypeKind::Arrow:
return TypeIndex::forArrowParamType(); return TypeIndex::forArrowParamType();
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Tuple = static_cast<TTuple*>(this); if (asTuple().ElementTypes.empty()) {
if (Tuple->ElementTypes.empty()) {
return TypeIndex(TypeIndexKind::End); return TypeIndex(TypeIndexKind::End);
} }
return TypeIndex::forTupleElement(0); return TypeIndex::forTupleElement(0);
@ -426,29 +315,38 @@ namespace bolt {
} }
} }
TypeIndex Type::getEndIndex() { TypeIndex Type::getEndIndex() const {
return TypeIndex(TypeIndexKind::End); return TypeIndex(TypeIndexKind::End);
} }
bool Type::hasTypeVar(Type* TV) const {
inline Type* TVar::find() { switch (Kind) {
TVar* Curr = this; case TypeKind::Var:
for (;;) { return Var.Id == TV->asVar().Id;
auto Keep = Curr->Parent; case TypeKind::Con:
if (Keep->getKind() != TypeKind::Var || Keep == Curr) { case TypeKind::Absent:
return Keep; case TypeKind::Nil:
return false;
case TypeKind::App:
return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV);
case TypeKind::Tuple:
for (auto Ty: Tuple.ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
} }
auto TV = static_cast<TVar*>(Keep); }
Curr->Parent = TV->Parent; return false;
Curr = TV; case TypeKind::TupleIndex:
return TupleIndex.Ty->hasTypeVar(TV);
case TypeKind::Field:
return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV);
case TypeKind::Arrow:
return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV);
case TypeKind::Present:
return Present.Ty->hasTypeVar(TV);
} }
} }
void TVar::set(Type* Ty) {
auto Root = find();
// It is not possible to set a solution twice.
ZEN_ASSERT(Root->getKind() == TypeKind::Var);
static_cast<TVar*>(Root)->Parent = Ty;
} }
}