Make Type a union and fix checking of tuple access

This commit is contained in:
Sam Vervaeck 2024-01-21 00:18:09 +01:00
parent 75124d097b
commit 4e11af005c
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
8 changed files with 1092 additions and 826 deletions

View file

@ -184,6 +184,12 @@ namespace bolt {
template<typename T>
NodeKind getNodeType();
enum NodeFlags {
NodeFlags_TypeIsSolved = 1,
};
using NodeFlagsMask = unsigned;
class Node {
unsigned RefCount = 1;
@ -192,6 +198,7 @@ namespace bolt {
public:
NodeFlagsMask Flags = 0;
Node* Parent = nullptr;
inline void ref() {

View file

@ -178,6 +178,7 @@ namespace bolt {
Type* ListType;
Type* IntType;
Type* StringType;
Type* UnitType;
Graph<Node*> RefGraph;
@ -217,9 +218,9 @@ namespace bolt {
/// Factory methods
TCon* createConType(ByteString Name);
TVar* createTypeVar();
TVarRigid* createRigidVar(ByteString Name);
Type* createConType(ByteString Name);
Type* createTypeVar();
Type* createRigidVar(ByteString Name);
InferContext* createInferContext(
InferContext* Parent = nullptr,
TVSet* TVs = new TVSet,
@ -280,6 +281,11 @@ namespace bolt {
*/
Type* simplifyType(Type* Ty);
/**
* \internal
*/
Type* solveType(Type* Ty);
void check(SourceFile* SF);
inline Type* getBoolType() const {

View file

@ -22,12 +22,17 @@ namespace bolt {
public:
bool FailOnError = false;
inline bool hasError() const noexcept {
return HasError;
}
template<typename D, typename ...Ts>
void add(Ts&&... Args) {
// if (FailOnError) {
// ZEN_PANIC("An error diagnostic caused the program to abort.");
// }
HasError = true;
addDiagnostic(new D { std::forward<Ts>(Args)... });
}

View file

@ -2,7 +2,6 @@
#pragma once
#include <vector>
#include <memory>
#include "bolt/ByteString.hpp"
#include "bolt/String.hpp"
@ -12,15 +11,16 @@
namespace bolt {
enum class DiagnosticKind : unsigned char {
UnexpectedToken,
UnexpectedString,
BindingNotFound,
UnificationError,
TypeclassMissing,
InstanceNotFound,
TupleIndexOutOfRange,
InvalidTypeToTypeclass,
FieldNotFound,
InstanceNotFound,
InvalidTypeToTypeclass,
NotATuple,
TupleIndexOutOfRange,
TypeclassMissing,
UnexpectedString,
UnexpectedToken,
UnificationError,
};
class Diagnostic : std::runtime_error {
@ -168,10 +168,10 @@ namespace bolt {
class TupleIndexOutOfRangeDiagnostic : public Diagnostic {
public:
TTuple* Tuple;
Type* Tuple;
std::size_t I;
inline TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I):
inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I):
Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {}
unsigned getCode() const noexcept override {
@ -217,4 +217,18 @@ namespace bolt {
};
class NotATupleDiagnostic : public Diagnostic {
public:
Type* Ty;
inline NotATupleDiagnostic(Type* Ty):
Diagnostic(DiagnosticKind::NotATuple), Ty(Ty) {}
unsigned getCode() const noexcept override {
return 2016;
}
};
}

View file

@ -2,17 +2,21 @@
#pragma once
#include <functional>
#include <type_traits>
#include <vector>
#include <unordered_map>
#include <optional>
#include <unistd.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/CST.hpp"
#include "bolt/ByteString.hpp"
namespace bolt {
class Type;
class TVar;
class TCon;
using TypeclassId = ByteString;
@ -23,7 +27,7 @@ namespace bolt {
using TypeclassId = ByteString;
TypeclassId Id;
std::vector<TVar*> Params;
std::vector<Type*> Params;
bool operator<(const TypeclassSignature& Other) const;
bool operator==(const TypeclassSignature& Other) const;
@ -144,8 +148,8 @@ namespace bolt {
using TypePath = std::vector<TypeIndex>;
using TVSub = std::unordered_map<TVar*, Type*>;
using TVSet = std::unordered_set<TVar*>;
using TVSub = std::unordered_map<Type*, Type*>;
using TVSet = std::unordered_set<Type*>;
enum class TypeKind : unsigned char {
Var,
@ -160,48 +164,402 @@ namespace bolt {
Present,
};
class Type {
class Type;
const TypeKind Kind;
struct TCon {
size_t Id;
ByteString DisplayName;
protected:
bool operator==(const TCon& Other) const;
inline Type(TypeKind Kind):
Kind(Kind) {}
};
public:
struct TApp {
Type* Op;
Type* Arg;
inline TypeKind getKind() const noexcept {
bool operator==(const TApp& Other) const;
};
enum class VarKind {
Rigid,
Unification,
};
struct TVar {
VarKind VK;
size_t Id;
TypeclassContext Context;
std::optional<ByteString> Name;
std::optional<TypeclassContext> Provided;
VarKind getKind() const {
return VK;
}
bool isUni() const {
return VK == VarKind::Unification;
}
bool isRigid() const {
return VK == VarKind::Rigid;
}
bool operator==(const TVar& Other) const;
};
struct TArrow {
Type* ParamType;
Type* ReturnType;
bool operator==(const TArrow& Other) const;
};
struct TTuple {
std::vector<Type*> ElementTypes;
bool operator==(const TTuple& Other) const;
};
struct TTupleIndex {
Type* Ty;
std::size_t I;
bool operator==(const TTupleIndex& Other) const;
};
struct TNil {
bool operator==(const TNil& Other) const;
};
struct TField {
ByteString Name;
Type* Ty;
Type* RestTy;
bool operator==(const TField& Other) const;
};
struct TAbsent {
bool operator==(const TAbsent& Other) const;
};
struct TPresent {
Type* Ty;
bool operator==(const TPresent& Other) const;
};
struct Type {
TypeKind Kind;
Type* Parent = this;
union {
TCon Con;
TApp App;
TVar Var;
TArrow Arrow;
TTuple Tuple;
TTupleIndex TupleIndex;
TNil Nil;
TField Field;
TAbsent Absent;
TPresent Present;
};
Type(TCon&& Con):
Kind(TypeKind::Con), Con(std::move(Con)) {};
Type(TApp&& App):
Kind(TypeKind::App), App(std::move(App)) {};
Type(TVar&& Var):
Kind(TypeKind::Var), Var(std::move(Var)) {};
Type(TArrow&& Arrow):
Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {};
Type(TTuple&& Tuple):
Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {};
Type(TTupleIndex&& TupleIndex):
Kind(TypeKind::TupleIndex), TupleIndex(std::move(TupleIndex)) {};
Type(TNil&& Nil):
Kind(TypeKind::Nil), Nil(std::move(Nil)) {};
Type(TField&& Field):
Kind(TypeKind::Field), Field(std::move(Field)) {};
Type(TAbsent&& Absent):
Kind(TypeKind::Absent), Absent(std::move(Absent)) {};
Type(TPresent&& Present):
Kind(TypeKind::Present), Present(std::move(Present)) {};
// Type(TCon Con): Kind(TypeKind::Con) {
// new (&Con)TCon(Con);
// }
// Type(TApp App): Kind(TypeKind::App) {
// new (&App)TApp(App);
// }
// Type(TVar Var): Kind(TypeKind::Var) {
// new (&Var)TVar(Var);
// }
// Type(TArrow Arrow): Kind(TypeKind::Arrow) {
// new (&Arrow)TArrow(Arrow);
// }
// Type(TTuple Tuple): Kind(TypeKind::Tuple) {
// new (&Tuple)TTuple(Tuple);
// }
// Type(TTupleIndex TupleIndex): Kind(TypeKind::TupleIndex) {
// new (&TupleIndex)TTupleIndex(TupleIndex);
// }
// Type(TNil Nil): Kind(TypeKind::Nil) {
// new (&Nil)TNil(Nil);
// }
// Type(TField Field): Kind(TypeKind::Field) {
// new (&Field)TField(Field);
// }
// Type(TAbsent Absent): Kind(TypeKind::Absent) {
// new (&Absent)TAbsent(Absent);
// }
// Type(TPresent Present): Kind(TypeKind::Present) {
// new (&Present)TPresent(Present);
// }
Type(const Type& Other): Kind(Other.Kind) {
switch (Kind) {
case TypeKind::Con:
new (&Con)TCon(Other.Con);
break;
case TypeKind::App:
new (&App)TApp(Other.App);
break;
case TypeKind::Var:
new (&Var)TVar(Other.Var);
break;
case TypeKind::Arrow:
new (&Arrow)TArrow(Other.Arrow);
break;
case TypeKind::Tuple:
new (&Tuple)TTuple(Other.Tuple);
break;
case TypeKind::TupleIndex:
new (&TupleIndex)TTupleIndex(Other.TupleIndex);
break;
case TypeKind::Nil:
new (&Nil)TNil(Other.Nil);
break;
case TypeKind::Field:
new (&Field)TField(Other.Field);
break;
case TypeKind::Absent:
new (&Absent)TAbsent(Other.Absent);
break;
case TypeKind::Present:
new (&Present)TPresent(Other.Present);
break;
}
}
Type(Type&& Other): Kind(std::move(Other.Kind)) {
switch (Kind) {
case TypeKind::Con:
new (&Con)TCon(std::move(Other.Con));
break;
case TypeKind::App:
new (&App)TApp(std::move(Other.App));
break;
case TypeKind::Var:
new (&Var)TVar(std::move(Other.Var));
break;
case TypeKind::Arrow:
new (&Arrow)TArrow(std::move(Other.Arrow));
break;
case TypeKind::Tuple:
new (&Tuple)TTuple(std::move(Other.Tuple));
break;
case TypeKind::TupleIndex:
new (&TupleIndex)TTupleIndex(std::move(Other.TupleIndex));
break;
case TypeKind::Nil:
new (&Nil)TNil(std::move(Other.Nil));
break;
case TypeKind::Field:
new (&Field)TField(std::move(Other.Field));
break;
case TypeKind::Absent:
new (&Absent)TAbsent(std::move(Other.Absent));
break;
case TypeKind::Present:
new (&Present)TPresent(std::move(Other.Present));
break;
}
}
TypeKind getKind() const {
return Kind;
}
bool hasTypeVar(const TVar* TV);
void addTypeVars(TVSet& TVs);
inline TVSet getTypeVars() {
TVSet Out;
addTypeVars(Out);
return Out;
bool isVarRigid() const {
return Kind == TypeKind::Var
&& asVar().getKind() == VarKind::Rigid;
}
/**
* Rewrites the entire substructure of a type to another one.
*
* \param Recursive If true, a succesfull local rewritten type will be again
* rewriten until it encounters some terminals.
*/
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = false);
bool isVar() const {
return Kind == TypeKind::Var;
}
Type* substitute(const TVSub& Sub);
TVar& asVar() {
ZEN_ASSERT(Kind == TypeKind::Var);
return Var;
}
Type* solve();
const TVar& asVar() const {
ZEN_ASSERT(Kind == TypeKind::Var);
return Var;
}
TypeIterator begin();
TypeIterator end();
bool isApp() const {
return Kind == TypeKind::App;
}
TypeIndex getStartIndex();
TypeIndex getEndIndex();
TApp& asApp() {
ZEN_ASSERT(Kind == TypeKind::App);
return App;
}
const TApp& asApp() const {
ZEN_ASSERT(Kind == TypeKind::App);
return App;
}
bool isCon() const {
return Kind == TypeKind::Con;
}
TCon& asCon() {
ZEN_ASSERT(Kind == TypeKind::Con);
return Con;
}
const TCon& asCon() const {
ZEN_ASSERT(Kind == TypeKind::Con);
return Con;
}
bool isArrow() const {
return Kind == TypeKind::Arrow;
}
TArrow& asArrow() {
ZEN_ASSERT(Kind == TypeKind::Arrow);
return Arrow;
}
const TArrow& asArrow() const {
ZEN_ASSERT(Kind == TypeKind::Arrow);
return Arrow;
}
bool isTuple() const {
return Kind == TypeKind::Tuple;
}
TTuple& asTuple() {
ZEN_ASSERT(Kind == TypeKind::Tuple);
return Tuple;
}
const TTuple& asTuple() const {
ZEN_ASSERT(Kind == TypeKind::Tuple);
return Tuple;
}
bool isTupleIndex() const {
return Kind == TypeKind::TupleIndex;
}
TTupleIndex& asTupleIndex() {
ZEN_ASSERT(Kind == TypeKind::TupleIndex);
return TupleIndex;
}
const TTupleIndex& asTupleIndex() const {
ZEN_ASSERT(Kind == TypeKind::TupleIndex);
return TupleIndex;
}
bool isField() const {
return Kind == TypeKind::Field;
}
TField& asField() {
ZEN_ASSERT(Kind == TypeKind::Field);
return Field;
}
const TField& asField() const {
ZEN_ASSERT(Kind == TypeKind::Field);
return Field;
}
bool isAbsent() const {
return Kind == TypeKind::Absent;
}
TAbsent& asAbsent() {
ZEN_ASSERT(Kind == TypeKind::Absent);
return Absent;
}
const TAbsent& asAbsent() const {
ZEN_ASSERT(Kind == TypeKind::Absent);
return Absent;
}
bool isPresent() const {
return Kind == TypeKind::Present;
}
TPresent& asPresent() {
ZEN_ASSERT(Kind == TypeKind::Present);
return Present;
}
const TPresent& asPresent() const {
ZEN_ASSERT(Kind == TypeKind::Present);
return Present;
}
bool isNil() const {
return Kind == TypeKind::Nil;
}
TNil& asNil() {
ZEN_ASSERT(Kind == TypeKind::Nil);
return Nil;
}
const TNil& asNil() const {
ZEN_ASSERT(Kind == TypeKind::Nil);
return Nil;
}
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = true);
Type* resolve(const TypeIndex& Index) const noexcept;
@ -213,207 +571,128 @@ namespace bolt {
return Ty;
}
bool operator==(const Type& Other) const noexcept;
bool operator!=(const Type& Other) const noexcept {
return !(*this == Other);
void set(Type* Ty) {
auto Root = find();
// It is not possible to set a solution twice.
if (isVar()) {
ZEN_ASSERT(Root->isVar());
}
Root->Parent = Ty;
}
};
class TCon : public Type {
public:
const size_t Id;
ByteString DisplayName;
inline TCon(const size_t Id, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), DisplayName(DisplayName) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con;
Type* find() const {
Type* Curr = const_cast<Type*>(this);
for (;;) {
auto Keep = Curr->Parent;
if (Keep == Curr) {
return Keep;
}
Curr->Parent = Keep->Parent;
Curr = Keep;
}
}
};
bool operator==(const Type& Other) const;
class TApp : public Type {
public:
Type* Op;
Type* Arg;
inline TApp(Type* Op, Type* Arg):
Type(TypeKind::App), Op(Op), Arg(Arg) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::App;
void destroy() {
switch (Kind) {
case TypeKind::Con:
App.~TApp();
break;
case TypeKind::App:
App.~TApp();
break;
case TypeKind::Var:
Var.~TVar();
break;
case TypeKind::Arrow:
Arrow.~TArrow();
break;
case TypeKind::Tuple:
Tuple.~TTuple();
break;
case TypeKind::TupleIndex:
TupleIndex.~TTupleIndex();
break;
case TypeKind::Nil:
Nil.~TNil();
break;
case TypeKind::Field:
Field.~TField();
break;
case TypeKind::Absent:
Absent.~TAbsent();
break;
case TypeKind::Present:
Present.~TPresent();
break;
}
}
};
enum class VarKind {
Rigid,
Unification,
};
class TVar : public Type {
Type* Parent = this;
public:
const size_t Id;
VarKind VK;
TypeclassContext Contexts;
inline TVar(size_t Id, VarKind VK):
Type(TypeKind::Var), Id(Id), VK(VK) {}
inline VarKind getVarKind() const noexcept {
return VK;
Type& operator=(Type& Other) {
destroy();
Kind = Other.Kind;
switch (Kind) {
case TypeKind::Con:
App = Other.App;
break;
case TypeKind::App:
App = Other.App;
break;
case TypeKind::Var:
Var = Other.Var;
break;
case TypeKind::Arrow:
Arrow = Other.Arrow;
break;
case TypeKind::Tuple:
Tuple = Other.Tuple;
break;
case TypeKind::TupleIndex:
TupleIndex = Other.TupleIndex;
break;
case TypeKind::Nil:
Nil = Other.Nil;
break;
case TypeKind::Field:
Field = Other.Field;
break;
case TypeKind::Absent:
Absent = Other.Absent;
break;
case TypeKind::Present:
Present = Other.Present;
break;
}
return *this;
}
inline bool isRigid() const noexcept {
return VK == VarKind::Rigid;
bool hasTypeVar(Type* TV) const;
TypeIterator begin();
TypeIterator end();
TypeIndex getStartIndex() const;
TypeIndex getEndIndex() const;
Type* substitute(const TVSub& Sub);
void visitEachChild(std::function<void(Type*)> Proc);
TVSet getTypeVars();
~Type() {
destroy();
}
Type* find();
void set(Type* Ty);
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var;
}
};
class TVarRigid : public TVar {
public:
ByteString Name;
TypeclassContext Provided;
inline TVarRigid(size_t Id, ByteString Name):
TVar(Id, VarKind::Rigid), Name(Name) {}
};
class TArrow : public Type {
public:
Type* ParamType;
Type* ReturnType;
inline TArrow(
Type* ParamType,
Type* ReturnType
): Type(TypeKind::Arrow),
ParamType(ParamType),
ReturnType(ReturnType) {}
static Type* build(std::vector<Type*> ParamTypes, Type* ReturnType) {
static Type* buildArrow(std::vector<Type*> ParamTypes, Type* ReturnType) {
Type* Curr = ReturnType;
for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) {
Curr = new TArrow(*Iter, Curr);
Curr = new Type(TArrow(*Iter, Curr));
}
return Curr;
}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Arrow;
}
};
class TTuple : public Type {
public:
std::vector<Type*> ElementTypes;
inline TTuple(std::vector<Type*> ElementTypes):
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Tuple;
}
};
class TTupleIndex : public Type {
public:
Type* Ty;
std::size_t I;
inline TTupleIndex(Type* Ty, std::size_t I):
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::TupleIndex;
}
};
class TNil : public Type {
public:
inline TNil():
Type(TypeKind::Nil) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Nil;
}
};
class TField : public Type {
public:
ByteString Name;
Type* Ty;
Type* RestTy;
inline TField(
ByteString Name,
Type* Ty,
Type* RestTy
): Type(TypeKind::Field),
Name(Name),
Ty(Ty),
RestTy(RestTy) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Field;
}
};
class TAbsent : public Type {
public:
inline TAbsent():
Type(TypeKind::Absent) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Absent;
}
};
class TPresent : public Type {
public:
Type* Ty;
inline TPresent(Type* Ty):
Type(TypeKind::Present), Ty(Ty) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Present;
}
};
template<bool IsConst>
@ -426,48 +705,49 @@ namespace bolt {
virtual void enterType(C<Type>* Ty) {}
virtual void exitType(C<Type>* Ty) {}
virtual void visitType(C<Type>* Ty) {
visitEachChild(Ty);
// virtual void visitType(C<Type>* Ty) {
// visitEachChild(Ty);
// }
virtual void visitVarType(C<TVar>& Ty) {
}
virtual void visitVarType(C<TVar>* Ty) {
visitType(Ty);
virtual void visitAppType(C<TApp>& Ty) {
visit(Ty.Op);
visit(Ty.Arg);
}
virtual void visitAppType(C<TApp>* Ty) {
visitType(Ty);
virtual void visitPresentType(C<TPresent>& Ty) {
visit(Ty.Ty);
}
virtual void visitPresentType(C<TPresent>* Ty) {
visitType(Ty);
virtual void visitConType(C<TCon>& Ty) {
}
virtual void visitConType(C<TCon>* Ty) {
visitType(Ty);
virtual void visitArrowType(C<TArrow>& Ty) {
visit(Ty.ParamType);
visit(Ty.ReturnType);
}
virtual void visitArrowType(C<TArrow>* Ty) {
visitType(Ty);
virtual void visitTupleType(C<TTuple>& Ty) {
for (auto ElTy: Ty.ElementTypes) {
visit(ElTy);
}
}
virtual void visitTupleType(C<TTuple>* Ty) {
visitType(Ty);
virtual void visitTupleIndexType(C<TTupleIndex>& Ty) {
visit(Ty.Ty);
}
virtual void visitTupleIndexType(C<TTupleIndex>* Ty) {
visitType(Ty);
virtual void visitAbsentType(C<TAbsent>& Ty) {
}
virtual void visitAbsentType(C<TAbsent>* Ty) {
visitType(Ty);
virtual void visitFieldType(C<TField>& Ty) {
visit(Ty.Ty);
visit(Ty.RestTy);
}
virtual void visitFieldType(C<TField>* Ty) {
visitType(Ty);
}
virtual void visitNilType(C<TNil>* Ty) {
visitType(Ty);
virtual void visitNilType(C<TNil>& Ty) {
}
public:
@ -481,14 +761,14 @@ namespace bolt {
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<C<TArrow>*>(Ty);
auto& Arrow = Ty->asArrow();
visit(Arrow->ParamType);
visit(Arrow->ReturnType);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<C<TTuple>*>(Ty);
auto& Tuple = Ty->asTuple();
for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) {
visit(Tuple->ElementTypes[I]);
}
@ -496,27 +776,27 @@ namespace bolt {
}
case TypeKind::App:
{
auto App = static_cast<C<TApp>*>(Ty);
auto& App = Ty->asApp();
visit(App->Op);
visit(App->Arg);
break;
}
case TypeKind::Field:
{
auto Field = static_cast<C<TField>*>(Ty);
auto& Field = Ty->asField();
visit(Field->Ty);
visit(Field->RestTy);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<C<TPresent>*>(Ty);
auto& Present = Ty->asPresent();
visit(Present->Ty);
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<C<TTupleIndex>*>(Ty);
auto& Index = Ty->asTupleIndex();
visit(Index->Ty);
break;
}
@ -524,37 +804,41 @@ namespace bolt {
}
void visit(C<Type>* Ty) {
// Always look at the most solved solution
Ty = Ty->find();
enterType(Ty);
switch (Ty->getKind()) {
case TypeKind::Present:
visitPresentType(static_cast<C<TPresent>*>(Ty));
visitPresentType(Ty->asPresent());
break;
case TypeKind::Absent:
visitAbsentType(static_cast<C<TAbsent>*>(Ty));
visitAbsentType(Ty->asAbsent());
break;
case TypeKind::Nil:
visitNilType(static_cast<C<TNil>*>(Ty));
visitNilType(Ty->asNil());
break;
case TypeKind::Field:
visitFieldType(static_cast<C<TField>*>(Ty));
visitFieldType(Ty->asField());
break;
case TypeKind::Con:
visitConType(static_cast<C<TCon>*>(Ty));
visitConType(Ty->asCon());
break;
case TypeKind::Arrow:
visitArrowType(static_cast<C<TArrow>*>(Ty));
visitArrowType(Ty->asArrow());
break;
case TypeKind::Var:
visitVarType(static_cast<C<TVar>*>(Ty));
visitVarType(Ty->asVar());
break;
case TypeKind::Tuple:
visitTupleType(static_cast<C<TTuple>*>(Ty));
visitTupleType(Ty->asTuple());
break;
case TypeKind::App:
visitAppType(static_cast<C<TApp>*>(Ty));
visitAppType(Ty->asApp());
break;
case TypeKind::TupleIndex:
visitTupleIndexType(static_cast<C<TTupleIndex>*>(Ty));
visitTupleIndexType(Ty->asTupleIndex());
break;
}
exitType(Ty);
@ -567,11 +851,4 @@ namespace bolt {
using TypeVisitor = TypeVisitorBase<false>;
using ConstTypeVisitor = TypeVisitorBase<true>;
// template<typename T>
// struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept {
// return std::hash<decltype(*Value)>{}(*Value);
// }
// };
}

View file

@ -1,13 +1,11 @@
#include <algorithm>
#include <iterator>
#include <stack>
#include <map>
#include "bolt/Type.hpp"
#include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/Type.hpp"
#include "bolt/CSTVisitor.hpp"
#include "bolt/DiagnosticEngine.hpp"
#include "bolt/Diagnostics.hpp"
@ -39,29 +37,30 @@ namespace bolt {
Type* Checker::simplifyType(Type* Ty) {
return Ty->rewrite([&](auto Ty) {
Ty = Ty->find();
if (Ty->getKind() == TypeKind::Var) {
Ty = static_cast<TVar*>(Ty)->find();
}
if (Ty->getKind() == TypeKind::TupleIndex) {
auto Index = static_cast<TTupleIndex*>(Ty);
auto MaybeTuple = simplifyType(Index->Ty);
if (MaybeTuple->getKind() == TypeKind::Tuple) {
auto Tuple = static_cast<TTuple*>(MaybeTuple);
if (Index->I >= Tuple->ElementTypes.size()) {
DE.add<TupleIndexOutOfRangeDiagnostic>(Tuple, Index->I);
if (Ty->isTupleIndex()) {
auto& Index = Ty->asTupleIndex();
auto MaybeTuple = simplifyType(Index.Ty);
if (MaybeTuple->isTuple()) {
auto& Tuple = MaybeTuple->asTuple();
if (Index.I >= Tuple.ElementTypes.size()) {
DE.add<TupleIndexOutOfRangeDiagnostic>(MaybeTuple, Index.I);
} else {
Ty = simplifyType(Tuple->ElementTypes[Index->I]);
auto ElementTy = simplifyType(Tuple.ElementTypes[Index.I]);
Ty->set(ElementTy);
Ty = ElementTy;
}
} else if (!MaybeTuple->isVar()) {
DE.add<NotATupleDiagnostic>(MaybeTuple);
}
}
return Ty;
}
}, /*Recursive=*/true);
Type* Checker::solveType(Type* Ty) {
return Ty->rewrite([this](auto Ty) { return simplifyType(Ty); }, true);
}
Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE):
@ -70,6 +69,7 @@ namespace bolt {
IntType = createConType("Int");
StringType = createConType("String");
ListType = createConType("List");
UnitType = new Type(TTuple({}));
}
Scheme* Checker::lookup(ByteString Name) {
@ -293,7 +293,7 @@ namespace bolt {
setContext(Decl->Ctx);
std::vector<TVar*> Vars;
std::vector<Type*> Vars;
for (auto TE: Decl->TVs) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
@ -312,13 +312,20 @@ namespace bolt {
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
auto RetTy = Ty;
for (auto Var: Vars) {
RetTy = new TApp(RetTy, Var);
RetTy = new Type(TApp(RetTy, Var));
}
std::vector<Type*> ParamTypes;
for (auto Element: TupleMember->Elements) {
ParamTypes.push_back(inferTypeExpression(Element));
}
Decl->Ctx->Parent->add(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, TArrow::build(ParamTypes, RetTy)));
Decl->Ctx->Parent->add(
TupleMember->Name->getCanonicalText(),
new Forall(
Decl->Ctx->TVs,
Decl->Ctx->Constraints,
Type::buildArrow(ParamTypes, RetTy)
)
);
break;
}
case NodeKind::RecordVariantDeclarationMember:
@ -342,7 +349,7 @@ namespace bolt {
setContext(Decl->Ctx);
std::vector<TVar*> Vars;
std::vector<Type*> Vars;
for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Vars.push_back(TV);
@ -355,15 +362,28 @@ namespace bolt {
Decl->Ctx->Parent->add(Name, new Forall(Ty));
// Corresponds to the logic of one branch of a VariantDeclarationMember
Type* FieldsTy = new TNil();
Type* FieldsTy = new Type(TNil());
for (auto Field: Decl->Fields) {
FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy);
FieldsTy = new Type(
TField(
Field->Name->getCanonicalText(),
new Type(TPresent(inferTypeExpression(Field->TypeExpression))),
FieldsTy
)
);
}
Type* RetTy = Ty;
for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV);
RetTy = new Type(TApp(RetTy, TV));
}
Decl->Ctx->Parent->add(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(FieldsTy, RetTy)));
Decl->Ctx->Parent->add(
Name,
new Forall(
Decl->Ctx->TVs,
Decl->Ctx->Constraints,
new Type(TArrow(FieldsTy, RetTy))
)
);
popContext();
break;
@ -444,11 +464,11 @@ namespace bolt {
auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) {
auto Id = Class->Name->getCanonicalText();
auto Ctx = &getContext();
std::vector<TVar*> Out;
std::vector<Type*> Out;
for (auto TE: Class->TypeVars) {
auto Name = TE->Name->getCanonicalText();
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
TV->Contexts.emplace(Id);
TV->asVar().Context.emplace(Id);
Ctx->add(Name, new Forall(TV));
Out.push_back(TV);
}
@ -586,7 +606,7 @@ namespace bolt {
RetType = createTypeVar();
}
makeEqual(Decl->getType(), TArrow::build(ParamTypes, RetType), Decl);
makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl);
setContext(OldCtx);
}
@ -648,8 +668,8 @@ namespace bolt {
if (RetStmt->Expression) {
makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression);
} else {
ReturnType = new TTuple({});
makeEqual(new TTuple({}), getReturnType(), N);
ReturnType = UnitType;
makeEqual(UnitType, getReturnType(), N);
}
break;
}
@ -691,18 +711,18 @@ namespace bolt {
}
TCon* Checker::createConType(ByteString Name) {
return new TCon(NextConTypeId++, Name);
Type* Checker::createConType(ByteString Name) {
return new Type(TCon(NextConTypeId++, Name));
}
TVarRigid* Checker::createRigidVar(ByteString Name) {
auto TV = new TVarRigid(NextTypeVarId++, Name);
Type* Checker::createRigidVar(ByteString Name) {
auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}}));
getContext().TVs->emplace(TV);
return TV;
}
TVar* Checker::createTypeVar() {
auto TV = new TVar(NextTypeVarId++, VarKind::Unification);
Type* Checker::createTypeVar() {
auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {}));
getContext().TVs->emplace(TV);
return TV;
}
@ -727,7 +747,7 @@ namespace bolt {
for (auto TV: *F->TVs) {
auto Fresh = createTypeVar();
// std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl;
Fresh->Contexts = TV->Contexts;
Fresh->asVar().Context = TV->asVar().Context;
Sub[TV] = Fresh;
}
@ -736,8 +756,8 @@ namespace bolt {
// FIXME improve this
if (Constraint->getKind() == ConstraintKind::Equal) {
auto Eq = static_cast<CEqual*>(Constraint);
Eq->Left = simplifyType(Eq->Left);
Eq->Right = simplifyType(Eq->Right);
Eq->Left = solveType(Eq->Left);
Eq->Right = solveType(Eq->Right);
}
auto NewConstraint = Constraint->substitute(Sub);
@ -752,11 +772,11 @@ namespace bolt {
addConstraint(NewConstraint);
}
// Note the call to simplify? This is because constraints may have already
// This call to solve happens because constraints may have already
// been solved, with some unification variables being erased. To make
// sure we instantiate unification variables that are still in use
// we solve before substituting.
return simplifyType(F->Type)->substitute(Sub);
return solveType(F->Type)->substitute(Sub);
}
}
@ -771,10 +791,8 @@ namespace bolt {
std::vector<Type*> Types;
for (auto TE: D->TEs) {
auto Ty = inferTypeExpression(TE);
ZEN_ASSERT(Ty->getKind() == TypeKind::Var && static_cast<TVar*>(Ty)->isRigid());
auto TV = static_cast<TVarRigid*>(Ty);
TV->Provided.emplace(D->Name->getCanonicalText());
Types.push_back(TV);
Ty->asVar().Provided->emplace(D->Name->getCanonicalText());
Types.push_back(Ty);
}
break;
}
@ -813,7 +831,7 @@ namespace bolt {
auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op, IsPoly);
for (auto Arg: AppTE->Args) {
Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly));
Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly)));
}
N->setType(Ty);
return Ty;
@ -830,9 +848,9 @@ namespace bolt {
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
}
ZEN_ASSERT(Ty->getKind() == TypeKind::Var);
ZEN_ASSERT(Ty->isVar());
N->setType(Ty);
return static_cast<TVar*>(Ty);
return Ty;
}
case NodeKind::TupleTypeExpression:
@ -842,7 +860,7 @@ namespace bolt {
for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE, IsPoly));
}
auto Ty = new TTuple(ElementTypes);
auto Ty = new Type(TTuple(ElementTypes));
N->setType(Ty);
return Ty;
}
@ -863,7 +881,7 @@ namespace bolt {
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
}
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
auto Ty = TArrow::build(ParamTypes, ReturnType);
auto Ty = Type::buildArrow(ParamTypes, ReturnType);
N->setType(Ty);
return Ty;
}
@ -886,14 +904,14 @@ namespace bolt {
}
Type* sortRow(Type* Ty) {
std::map<ByteString, TField*> Fields;
while (Ty->getKind() == TypeKind::Field) {
auto Field = static_cast<TField*>(Ty);
Fields.emplace(Field->Name, Field);
Ty = Field->RestTy;
std::map<ByteString, Type*> Fields;
while (Ty->isField()) {
auto& Field = Ty->asField();
Fields.emplace(Field.Name, Ty);
Ty = Field.RestTy;
}
for (auto [Name, Field]: Fields) {
Ty = new TField(Name, Field->Ty, Ty);
Ty = new Type(TField(Name, Field->asField().Ty, Ty));
}
return Ty;
}
@ -930,7 +948,7 @@ namespace bolt {
setContext(OldCtx);
}
if (!Match->Value) {
Ty = new TArrow(ValTy, Ty);
Ty = new Type(TArrow(ValTy, Ty));
}
break;
}
@ -938,9 +956,13 @@ namespace bolt {
case NodeKind::RecordExpression:
{
auto Record = static_cast<RecordExpression*>(X);
Ty = new TNil();
Ty = new Type(TNil());
for (auto [Field, Comma]: Record->Fields) {
Ty = new TField(Field->Name->getCanonicalText(), new TPresent(inferExpression(Field->getExpression())), Ty);
Ty = new Type(TField(
Field->Name->getCanonicalText(),
new Type(TPresent(inferExpression(Field->getExpression()))),
Ty
));
}
Ty = sortRow(Ty);
break;
@ -998,7 +1020,7 @@ namespace bolt {
for (auto Arg: Call->Args) {
ArgTypes.push_back(inferExpression(Arg));
}
makeEqual(OpTy, TArrow::build(ArgTypes, Ty), X);
makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X);
break;
}
@ -1008,14 +1030,15 @@ namespace bolt {
auto Scm = lookup(Infix->Operator->getText());
if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
return createTypeVar();
Ty = createTypeVar();
break;
}
auto OpTy = instantiate(Scm, Infix->Operator);
Ty = createTypeVar();
std::vector<Type*> ArgTys;
ArgTys.push_back(inferExpression(Infix->Left));
ArgTys.push_back(inferExpression(Infix->Right));
makeEqual(TArrow::build(ArgTys, Ty), OpTy, X);
makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X);
break;
}
@ -1026,7 +1049,7 @@ namespace bolt {
for (auto [E, Comma]: Tuple->Elements) {
Types.push_back(inferExpression(E));
}
Ty = new TTuple(Types);
Ty = new Type(TTuple(Types));
break;
}
@ -1038,7 +1061,7 @@ namespace bolt {
case NodeKind::IntegerLiteral:
{
auto I = static_cast<IntegerLiteral*>(Member->Name);
Ty = new TTupleIndex(ExprTy, I->getInteger());
Ty = new Type(TTupleIndex(ExprTy, I->getInteger()));
break;
}
case NodeKind::Identifier:
@ -1046,7 +1069,7 @@ namespace bolt {
auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar();
auto RestTy = createTypeVar();
makeEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member);
makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member);
break;
}
default:
@ -1102,7 +1125,7 @@ namespace bolt {
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
makeEqual(Ty, TArrow::build(ParamTypes, RetTy), P);
makeEqual(Ty, Type::buildArrow(ParamTypes, RetTy), P);
return RetTy;
}
@ -1113,7 +1136,7 @@ namespace bolt {
for (auto [Element, Comma]: P->Elements) {
ElementTypes.push_back(inferPattern(Element));
}
return new TTuple(ElementTypes);
return new Type(TTuple(ElementTypes));
}
case NodeKind::ListPattern:
@ -1123,7 +1146,7 @@ namespace bolt {
for (auto [Element, Separator]: P->Elements) {
makeEqual(ElementType, inferPattern(Element), P);
}
return new TApp(ListType, ElementType);
return new Type(TApp(ListType, ElementType));
}
case NodeKind::NestedPattern:
@ -1204,7 +1227,14 @@ namespace bolt {
}
Type* Checker::getType(TypedNode *Node) {
return Node->getType()->solve();
auto Ty = Node->getType();
if (Node->Flags & NodeFlags_TypeIsSolved) {
return Ty;
}
Ty = solveType(Ty);
Node->setType(Ty);
Node->Flags |= NodeFlags_TypeIsSolved;
return Ty;
}
void Checker::check(SourceFile *SF) {
@ -1217,11 +1247,11 @@ namespace bolt {
addBinding("True", new Forall(BoolType));
addBinding("False", new Forall(BoolType));
auto A = createTypeVar();
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, TArrow::build({ A, A }, BoolType)));
addBinding("+", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("-", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("*", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("/", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)));
addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)));
populate(SF);
forwardDeclare(SF);
auto SCCs = RefGraph.strongconnect();
@ -1243,6 +1273,27 @@ namespace bolt {
ActiveContext = nullptr;
solve(new CMany(*SF->Ctx->Constraints));
class Visitor : public CSTVisitor<Visitor> {
Checker& C;
public:
Visitor(Checker& C):
C(C) {}
void visitAnnotation(Annotation* A) {
}
void visitExpression(Expression* X) {
C.getType(X);
}
} V(*this);
V.visit(SF);
}
void Checker::solve(Constraint* Constraint) {
@ -1281,10 +1332,10 @@ namespace bolt {
}
bool assignableTo(Type* A, Type* B) {
if (isa<TCon>(A) && isa<TCon>(B)) {
auto Con1 = cast<TCon>(A);
auto Con2 = cast<TCon>(B);
if (Con1->Id != Con2-> Id) {
if (A->isCon() && B->isCon()) {
auto& Con1 = A->asCon();
auto& Con2 = B->asCon();
if (Con1.Id != Con2.Id) {
return false;
}
return true;
@ -1295,13 +1346,15 @@ namespace bolt {
class ArrowCursor {
std::stack<std::tuple<TArrow*, bool>> Stack;
/// Types on this stack are guaranteed to be arrow types.
std::stack<std::tuple<Type*, bool>> Stack;
TypePath& Path;
std::size_t I;
public:
ArrowCursor(TArrow* Arr, TypePath& Path):
ArrowCursor(Type* Arr, TypePath& Path):
Path(Path) {
Stack.push({ Arr, true });
Path.push_back(Arr->getStartIndex());
@ -1323,9 +1376,9 @@ namespace bolt {
continue;
}
Ty = Arrow->resolve(Index);
if (isa<TArrow>(Ty)) {
if (Ty->isArrow()) {
auto NewIndex = Arrow->getStartIndex();
Stack.push({ static_cast<TArrow*>(Ty), true });
Stack.push({ Ty, true });
Path.push_back(NewIndex);
} else {
return Ty;
@ -1390,40 +1443,36 @@ namespace bolt {
}
TypeSig getTypeSig(Type* Ty) {
struct Visitor : TypeVisitor {
Type* Op = nullptr;
std::vector<Type*> Args;
void visitType(Type* Ty) override {
if (!Op) {
std::function<void(Type*)> Visit = [&](Type* Ty) {
if (Ty->isApp()) {
Visit(Ty->asApp().Op);
Visit(Ty->asApp().Arg);
} else if (!Op) {
Op = Ty;
} else {
Args.push_back(Ty);
}
}
void visitAppType(TApp* Ty) override {
visitEachChild(Ty);
}
};
Visitor V;
V.visit(Ty);
return TypeSig { Ty, V.Op, V.Args };
Visit(Ty);
return TypeSig { Ty, Op, Args };
}
void propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty) {
if (isa<TVar>(Ty)) {
auto TV = cast<TVar>(Ty);
if (Ty->isVar()) {
auto TV = Ty->asVar();
for (auto Class: Classes) {
TV->Contexts.emplace(Class);
TV.Context.emplace(Class);
}
if (TV->isRigid()) {
auto RV = static_cast<TVarRigid*>(Ty);
for (auto Id: RV->Contexts) {
if (!RV->Provided.count(Id)) {
C.DE.add<TypeclassMissingDiagnostic>(TypeclassSignature { Id, { RV } }, getSource());
if (TV.isRigid()) {
for (auto Id: TV.Context) {
if (!TV.Provided->count(Id)) {
C.DE.add<TypeclassMissingDiagnostic>(TypeclassSignature { Id, { Ty } }, getSource());
}
}
}
} else if (isa<TCon>(Ty) || isa<TApp>(Ty)) {
} else if (Ty->isCon() || Ty->isApp()) {
auto Sig = getTypeSig(Ty);
for (auto Class: Classes) {
propagateClassTycon(Class, Sig);
@ -1450,13 +1499,13 @@ namespace bolt {
*
* Other side effects may occur.
*/
void join(TVar* TV, Type* Ty) {
void join(Type* TV, Type* Ty) {
// std::cerr << describe(TV) << " => " << describe(Ty) << std::endl;
TV->set(Ty);
propagateClasses(TV->Contexts, Ty);
propagateClasses(TV->asVar().Context, Ty);
// This is a very specific adjustment that is critical to the
// well-functioning of the infer/unify algorithm. When addConstraint() is
@ -1480,21 +1529,21 @@ namespace bolt {
};
bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) {
if (isa<TAbsent>(A) && isa<TAbsent>(B)) {
if (A->isAbsent() && B->isAbsent()) {
return true;
}
if (isa<TAbsent>(B)) {
if (B->isAbsent()) {
std::swap(A, B);
DidSwap = !DidSwap;
}
if (isa<TAbsent>(A)) {
auto Present = static_cast<TPresent*>(B);
C.DE.add<FieldNotFoundDiagnostic>(CurrentFieldName, C.simplifyType(getLeft()), LeftPath, getSource());
if (A->isAbsent()) {
auto& Present = B->asPresent();
C.DE.add<FieldNotFoundDiagnostic>(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource());
return false;
}
auto Present1 = static_cast<TPresent*>(A);
auto Present2 = static_cast<TPresent*>(B);
return unify(Present1->Ty, Present2->Ty, DidSwap);
auto& Present1 = A->asPresent();
auto& Present2 = B->asPresent();
return unify(Present1.Ty, Present2.Ty, DidSwap);
};
bool Unifier::unify(Type* A, Type* B, bool DidSwap) {
@ -1504,8 +1553,8 @@ namespace bolt {
auto unifyError = [&]() {
C.DE.add<UnificationErrorDiagnostic>(
C.simplifyType(Constraint->Left),
C.simplifyType(Constraint->Right),
Constraint->Left,
Constraint->Right,
LeftPath,
RightPath,
Constraint->Source
@ -1549,50 +1598,50 @@ namespace bolt {
DidSwap = !DidSwap;
};
if (isa<TVar>(A) && isa<TVar>(B)) {
auto Var1 = static_cast<TVar*>(A);
auto Var2 = static_cast<TVar*>(B);
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) {
if (Var1->Id != Var2->Id) {
if (A->isVar() && B->isVar()) {
auto& Var1 = A->asVar();
auto& Var2 = B->asVar();
if (Var1.isRigid() && Var2.isRigid()) {
if (Var1.Id != Var2.Id) {
unifyError();
return false;
}
return true;
}
TVar* To;
TVar* From;
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) {
To = Var1;
From = Var2;
Type* To;
Type* From;
if (Var1.isRigid() && Var2.isUni()) {
To = A;
From = B;
} else {
// Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification
// Either way, Var1, being Unification, is a good candidate for being unified away
To = Var2;
From = Var1;
To = B;
From = A;
}
if (From->Id != To->Id) {
if (From->asVar().Id != To->asVar().Id) {
join(From, To);
}
return true;
}
if (isa<TVar>(B)) {
if (B->isVar()) {
swap();
}
if (isa<TVar>(A)) {
if (A->isVar()) {
auto TV = static_cast<TVar*>(A);
auto& TV = A->asVar();
// Rigid type variables can never unify with antything else than what we
// have already handled in the previous if-statement, so issue an error.
if (TV->getVarKind() == VarKind::Rigid) {
if (TV.isRigid()) {
unifyError();
return false;
}
// Occurs check
if (B->hasTypeVar(TV)) {
if (B->hasTypeVar(A)) {
// NOTE Just like GHC, we just display an error message indicating that
// A cannot match B, e.g. a cannot match [a]. It looks much better
// than obsure references to an occurs check
@ -1600,25 +1649,25 @@ namespace bolt {
return false;
}
join(TV, B);
join(A, B);
return true;
}
if (isa<TArrow>(A) && isa<TArrow>(B)) {
auto Arrow1 = static_cast<TArrow*>(A);
auto Arrow2 = static_cast<TArrow*>(B);
if (A->isArrow() && B->isArrow()) {
auto& Arrow1 = A->asArrow();
auto& Arrow2 = B->asArrow();
bool Success = true;
LeftPath.push_back(TypeIndex::forArrowParamType());
RightPath.push_back(TypeIndex::forArrowParamType());
if (!unify(Arrow1->ParamType, Arrow2->ParamType, DidSwap)) {
if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
LeftPath.push_back(TypeIndex::forArrowReturnType());
RightPath.push_back(TypeIndex::forArrowReturnType());
if (!unify(Arrow1->ReturnType, Arrow2->ReturnType, DidSwap)) {
if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
@ -1626,20 +1675,20 @@ namespace bolt {
return Success;
}
if (isa<TApp>(A) && isa<TApp>(B)) {
auto App1 = static_cast<TApp*>(A);
auto App2 = static_cast<TApp*>(B);
if (A->isApp() && B->isApp()) {
auto& App1 = A->asApp();
auto& App2 = B->asApp();
bool Success = true;
LeftPath.push_back(TypeIndex::forAppOpType());
RightPath.push_back(TypeIndex::forAppOpType());
if (!unify(App1->Op, App2->Op, DidSwap)) {
if (!unify(App1.Op, App2.Op, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
LeftPath.push_back(TypeIndex::forAppArgType());
RightPath.push_back(TypeIndex::forAppArgType());
if (!unify(App1->Arg, App2->Arg, DidSwap)) {
if (!unify(App1.Arg, App2.Arg, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
@ -1647,19 +1696,19 @@ namespace bolt {
return Success;
}
if (isa<TTuple>(A) && isa<TTuple>(B)) {
auto Tuple1 = static_cast<TTuple*>(A);
auto Tuple2 = static_cast<TTuple*>(B);
if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) {
if (A->isTuple() && B->isTuple()) {
auto& Tuple1 = A->asTuple();
auto& Tuple2 = B->asTuple();
if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) {
unifyError();
return false;
}
auto Count = Tuple1->ElementTypes.size();
auto Count = Tuple1.ElementTypes.size();
bool Success = true;
for (size_t I = 0; I < Count; I++) {
LeftPath.push_back(TypeIndex::forTupleElement(I));
RightPath.push_back(TypeIndex::forTupleElement(I));
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], DidSwap)) {
if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) {
Success = false;
}
LeftPath.pop_back();
@ -1668,84 +1717,85 @@ namespace bolt {
return Success;
}
if (isa<TTupleIndex>(A) || isa<TTupleIndex>(B)) {
if (A->isTupleIndex() || B->isTupleIndex()) {
// Type(s) could not be simplified at the beginning of this function,
// so we have to re-visit the constraint when there is more information.
C.Queue.push_back(Constraint);
return true;
}
// if (isa<TTupleIndex>(A) && isa<TTupleIndex>(B)) {
// This does not work because it ignores the indices
// if (A->isTupleIndex() && B->isTupleIndex()) {
// auto Index1 = static_cast<TTupleIndex*>(A);
// auto Index2 = static_cast<TTupleIndex*>(B);
// return unify(Index1->Ty, Index2->Ty, Source);
// }
if (isa<TCon>(A) && isa<TCon>(B)) {
auto Con1 = static_cast<TCon*>(A);
auto Con2 = static_cast<TCon*>(B);
if (Con1->Id != Con2->Id) {
if (A->isCon() && B->isCon()) {
auto& Con1 = A->asCon();
auto& Con2 = B->asCon();
if (Con1.Id != Con2.Id) {
unifyError();
return false;
}
return true;
}
if (isa<TNil>(A) && isa<TNil>(B)) {
if (A->isNil() && B->isNil()) {
return true;
}
if (isa<TField>(A) && isa<TField>(B)) {
auto Field1 = static_cast<TField*>(A);
auto Field2 = static_cast<TField*>(B);
if (A->isField() && B->isField()) {
auto& Field1 = A->asField();
auto& Field2 = B->asField();
bool Success = true;
if (Field1->Name == Field2->Name) {
if (Field1.Name == Field2.Name) {
LeftPath.push_back(TypeIndex::forFieldType());
RightPath.push_back(TypeIndex::forFieldType());
CurrentFieldName = Field1->Name;
if (!unifyField(Field1->Ty, Field2->Ty, DidSwap)) {
CurrentFieldName = Field1.Name;
if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
LeftPath.push_back(TypeIndex::forFieldRest());
RightPath.push_back(TypeIndex::forFieldRest());
if (!unify(Field1->RestTy, Field2->RestTy, DidSwap)) {
if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
return Success;
}
auto NewRestTy = new TVar(C.NextTypeVarId++, VarKind::Unification);
auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++));
pushLeft(TypeIndex::forFieldRest());
if (!unify(Field1->RestTy, new TField(Field2->Name, Field2->Ty, NewRestTy), DidSwap)) {
if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) {
Success = false;
}
popLeft();
pushRight(TypeIndex::forFieldRest());
if (!unify(new TField(Field1->Name, Field1->Ty, NewRestTy), Field2->RestTy, DidSwap)) {
if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) {
Success = false;
}
popRight();
return Success;
}
if (isa<TNil>(A) && isa<TField>(B)) {
if (A->isNil() && B->isField()) {
swap();
}
if (isa<TField>(A) && isa<TNil>(B)) {
auto Field = static_cast<TField*>(A);
if (A->isField() && B->isNil()) {
auto& Field = A->asField();
bool Success = true;
pushLeft(TypeIndex::forFieldType());
CurrentFieldName = Field->Name;
if (!unifyField(Field->Ty, new TAbsent, DidSwap)) {
CurrentFieldName = Field.Name;
if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) {
Success = false;
}
popLeft();
pushLeft(TypeIndex::forFieldRest());
if (!unify(Field->RestTy, B, DidSwap)) {
if (!unify(Field.RestTy, B, DidSwap)) {
Success = false;
}
popLeft();
@ -1762,6 +1812,5 @@ namespace bolt {
A.unify();
}
}

View file

@ -193,41 +193,42 @@ namespace bolt {
}
std::string describe(const Type* Ty) {
Ty = Ty->find();
switch (Ty->getKind()) {
case TypeKind::Var:
{
auto TV = static_cast<const TVar*>(Ty);
if (TV->getVarKind() == VarKind::Rigid) {
return static_cast<const TVarRigid*>(TV)->Name;
auto TV = Ty->asVar();
if (TV.isRigid()) {
return *TV.Name;
}
return "a" + std::to_string(TV->Id);
return "a" + std::to_string(TV.Id);
}
case TypeKind::Arrow:
{
auto Y = static_cast<const TArrow*>(Ty);
auto Y = Ty->asArrow();
std::ostringstream Out;
Out << describe(Y->ParamType) << " -> " << describe(Y->ReturnType);
Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType);
return Out.str();
}
case TypeKind::Con:
{
auto Y = static_cast<const TCon*>(Ty);
return Y->DisplayName;
auto Y = Ty->asCon();
return Y.DisplayName;
}
case TypeKind::App:
{
auto Y = static_cast<const TApp*>(Ty);
return describe(Y->Op) + " " + describe(Y->Arg);
auto Y = Ty->asApp();
return describe(Y.Op) + " " + describe(Y.Arg);
}
case TypeKind::Tuple:
{
std::ostringstream Out;
auto Y = static_cast<const TTuple*>(Ty);
auto Y = Ty->asTuple();
Out << "(";
if (Y->ElementTypes.size()) {
auto Iter = Y->ElementTypes.begin();
if (Y.ElementTypes.size()) {
auto Iter = Y.ElementTypes.begin();
Out << describe(*Iter++);
while (Iter != Y->ElementTypes.end()) {
while (Iter != Y.ElementTypes.end()) {
Out << ", " << describe(*Iter++);
}
}
@ -236,8 +237,8 @@ namespace bolt {
}
case TypeKind::TupleIndex:
{
auto Y = static_cast<const TTupleIndex*>(Ty);
return describe(Y->Ty) + "." + std::to_string(Y->I);
auto Y = Ty->asTupleIndex();
return describe(Y.Ty) + "." + std::to_string(Y.I);
}
case TypeKind::Nil:
return "{}";
@ -245,19 +246,19 @@ namespace bolt {
return "Abs";
case TypeKind::Present:
{
auto Y = static_cast<const TPresent*>(Ty);
return describe(Y->Ty);
auto Y = Ty->asPresent();
return describe(Y.Ty);
}
case TypeKind::Field:
{
auto Y = static_cast<const TField*>(Ty);
auto Y = Ty->asField();
std::ostringstream out;
out << "{ " << Y->Name << ": " << describe(Y->Ty);
Ty = Y->RestTy;
out << "{ " << Y.Name << ": " << describe(Y.Ty);
Ty = Y.RestTy;
while (Ty->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty);
out << "; " + Y->Name + ": " + describe(Y->Ty);
Ty = Y->RestTy;
auto Y = Ty->asField();
out << "; " + Y.Name + ": " + describe(Y.Ty);
Ty = Y.RestTy;
}
if (Ty->getKind() != TypeKind::Nil) {
out << "; " + describe(Ty);
@ -561,53 +562,52 @@ namespace bolt {
void exitType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(false);
W.setUnderline(false); // FIXME Should set to old value
}
}
void visitAppType(const TApp *Ty) override {
auto Y = static_cast<const TApp*>(Ty);
void visitAppType(const TApp& Ty) override {
Path.push_back(TypeIndex::forAppOpType());
visit(Y->Op);
visit(Ty.Op);
Path.pop_back();
W.write(" ");
Path.push_back(TypeIndex::forAppArgType());
visit(Y->Arg);
visit(Ty.Arg);
Path.pop_back();
}
void visitVarType(const TVar* Ty) override {
if (Ty->getVarKind() == VarKind::Rigid) {
W.write(static_cast<const TVarRigid*>(Ty)->Name);
void visitVarType(const TVar& Ty) override {
if (Ty.isRigid()) {
W.write(*Ty.Name);
return;
}
W.write("a");
W.write(Ty->Id);
W.write(Ty.Id);
}
void visitConType(const TCon *Ty) override {
W.write(Ty->DisplayName);
void visitConType(const TCon& Ty) override {
W.write(Ty.DisplayName);
}
void visitArrowType(const TArrow* Ty) override {
void visitArrowType(const TArrow& Ty) override {
Path.push_back(TypeIndex::forArrowParamType());
visit(Ty->ParamType);
visit(Ty.ParamType);
Path.pop_back();
W.write(" -> ");
Path.push_back(TypeIndex::forArrowReturnType());
visit(Ty->ReturnType);
visit(Ty.ReturnType);
Path.pop_back();
}
void visitTupleType(const TTuple *Ty) override {
void visitTupleType(const TTuple& Ty) override {
W.write("(");
if (Ty->ElementTypes.size()) {
auto Iter = Ty->ElementTypes.begin();
if (Ty.ElementTypes.size()) {
auto Iter = Ty.ElementTypes.begin();
Path.push_back(TypeIndex::forTupleElement(0));
visit(*Iter++);
Path.pop_back();
std::size_t I = 1;
while (Iter != Ty->ElementTypes.end()) {
while (Iter != Ty.ElementTypes.end()) {
W.write(", ");
Path.push_back(TypeIndex::forTupleElement(I++));
visit(*Iter++);
@ -617,47 +617,47 @@ namespace bolt {
W.write(")");
}
void visitTupleIndexType(const TTupleIndex *Ty) override {
void visitTupleIndexType(const TTupleIndex& Ty) override {
Path.push_back(TypeIndex::forTupleIndexType());
visit(Ty->Ty);
visit(Ty.Ty);
Path.pop_back();
W.write(".");
W.write(Ty->I);
W.write(Ty.I);
}
void visitNilType(const TNil *Ty) override {
void visitNilType(const TNil& Ty) override {
W.write("{}");
}
void visitAbsentType(const TAbsent *Ty) override {
void visitAbsentType(const TAbsent& Ty) override {
W.write("Abs");
}
void visitPresentType(const TPresent *Ty) override {
void visitPresentType(const TPresent& Ty) override {
Path.push_back(TypeIndex::forPresentType());
visit(Ty->Ty);
visit(Ty.Ty);
Path.pop_back();
}
void visitFieldType(const TField* Ty) override {
void visitFieldType(const TField& Ty) override {
W.write("{ ");
W.write(Ty->Name);
W.write(Ty.Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Ty->Ty);
visit(Ty.Ty);
Path.pop_back();
auto Ty2 = Ty->RestTy;
auto Ty2 = Ty.RestTy;
Path.push_back(TypeIndex::forFieldRest());
std::size_t I = 1;
while (Ty2->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty2);
while (Ty2->isField()) {
auto Y = Ty2->asField();
W.write("; ");
W.write(Y->Name);
W.write(Y.Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Y->Ty);
visit(Y.Ty);
Path.pop_back();
Ty2 = Y->RestTy;
Ty2 = Y.RestTy;
Path.push_back(TypeIndex::forFieldRest());
++I;
}
@ -730,7 +730,7 @@ namespace bolt {
case DiagnosticKind::BindingNotFound:
{
auto E = static_cast<const BindingNotFoundDiagnostic&>(D);
auto& E = static_cast<const BindingNotFoundDiagnostic&>(D);
writePrefix(E);
write("binding ");
writeBinding(E.Name);
@ -746,7 +746,7 @@ namespace bolt {
case DiagnosticKind::UnexpectedToken:
{
auto E = static_cast<const UnexpectedTokenDiagnostic&>(D);
auto& E = static_cast<const UnexpectedTokenDiagnostic&>(D);
writePrefix(E);
writeLoc(E.File, E.Actual->getStartLoc());
write(" expected ");
@ -780,7 +780,7 @@ namespace bolt {
case DiagnosticKind::UnexpectedString:
{
auto E = static_cast<const UnexpectedStringDiagnostic&>(D);
auto& E = static_cast<const UnexpectedStringDiagnostic&>(D);
writePrefix(E);
writeLoc(E.File, E.Location);
write(" unexpected '");
@ -806,7 +806,7 @@ namespace bolt {
case DiagnosticKind::UnificationError:
{
auto E = static_cast<const UnificationErrorDiagnostic&>(D);
auto& E = static_cast<const UnificationErrorDiagnostic&>(D);
auto Left = E.OrigLeft->resolve(E.LeftPath);
auto Right = E.OrigRight->resolve(E.RightPath);
writePrefix(E);
@ -857,7 +857,7 @@ namespace bolt {
case DiagnosticKind::TypeclassMissing:
{
auto E = static_cast<const TypeclassMissingDiagnostic&>(D);
auto& E = static_cast<const TypeclassMissingDiagnostic&>(D);
writePrefix(E);
write("the type class ");
writeTypeclassSignature(E.Sig);
@ -869,7 +869,7 @@ namespace bolt {
case DiagnosticKind::InstanceNotFound:
{
auto E = static_cast<const InstanceNotFoundDiagnostic&>(D);
auto& E = static_cast<const InstanceNotFoundDiagnostic&>(D);
writePrefix(E);
write("a type class instance ");
writeTypeclassName(E.TypeclassName);
@ -883,7 +883,7 @@ namespace bolt {
case DiagnosticKind::TupleIndexOutOfRange:
{
auto E = static_cast<const TupleIndexOutOfRangeDiagnostic&>(D);
auto& E = static_cast<const TupleIndexOutOfRangeDiagnostic&>(D);
writePrefix(E);
write("the index ");
writeType(E.I);
@ -894,7 +894,7 @@ namespace bolt {
case DiagnosticKind::InvalidTypeToTypeclass:
{
auto E = static_cast<const InvalidTypeToTypeclassDiagnostic&>(D);
auto& E = static_cast<const InvalidTypeToTypeclassDiagnostic&>(D);
writePrefix(E);
write("the type ");
writeType(E.Actual);
@ -911,7 +911,7 @@ namespace bolt {
case DiagnosticKind::FieldNotFound:
{
auto E = static_cast<const FieldNotFoundDiagnostic&>(D);
auto& E = static_cast<const FieldNotFoundDiagnostic&>(D);
writePrefix(E);
write("the field '");
write(E.Name);
@ -921,6 +921,16 @@ namespace bolt {
break;
}
case DiagnosticKind::NotATuple:
{
auto& E = static_cast<const NotATupleDiagnostic&>(D);
writePrefix(E);
write("the type ");
writeType(E.Ty);
write(" is not a tuple.\n");
break;
}
}
}

View file

@ -1,9 +1,8 @@
#include "zen/config.hpp"
#include "zen/range.hpp"
#include "bolt/Common.hpp"
#include "bolt/Type.hpp"
#include <cwchar>
#include <sys/wait.h>
#include <vector>
namespace bolt {
@ -13,13 +12,13 @@ namespace bolt {
}
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Params[0]->Id < Other.Params[0]->Id;
return Params[0]->asCon().Id < Other.Params[0]->asCon().Id;
}
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id;
}
bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
@ -35,34 +34,120 @@ namespace bolt {
}
}
void TypeIndex::advance(const Type* Ty) {
switch (Kind) {
case TypeIndexKind::End:
break;
case TypeIndexKind::AppOpType:
Kind = TypeIndexKind::AppArgType;
break;
case TypeIndexKind::ArrowParamType:
Kind = TypeIndexKind::ArrowReturnType;
break;
case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End;
break;
case TypeIndexKind::FieldType:
Kind = TypeIndexKind::FieldRestType;
break;
case TypeIndexKind::FieldRestType:
case TypeIndexKind::TupleIndexType:
case TypeIndexKind::PresentType:
case TypeIndexKind::AppArgType:
case TypeIndexKind::TupleElement:
{
auto Tuple = cast<TTuple>(Ty);
if (I+1 < Tuple->ElementTypes.size()) {
++I;
} else {
Kind = TypeIndexKind::End;
bool TCon::operator==(const TCon& Other) const {
return Id == Other.Id;
}
bool TApp::operator==(const TApp& Other) const {
return *Op == *Other.Op && *Arg == *Other.Arg;
}
bool TVar::operator==(const TVar& Other) const {
return Id == Other.Id;
}
bool TArrow::operator==(const TArrow& Other) const {
return *ParamType == *Other.ParamType
&& *ReturnType == *Other.ReturnType;
}
bool TTuple::operator==(const TTuple& Other) const {
for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
bool TTupleIndex::operator==(const TTupleIndex& Other) const {
return *Ty == *Other.Ty && I == Other.I;
}
bool TNil::operator==(const TNil& Other) const {
return true;
}
bool TField::operator==(const TField& Other) const {
return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy;
}
bool TAbsent::operator==(const TAbsent& Other) const {
return true;
}
bool TPresent::operator==(const TPresent& Other) const {
return *Ty == *Other.Ty;
}
bool Type::operator==(const Type& Other) const {
if (Kind != Other.Kind) {
return false;
}
switch (Kind) {
case TypeKind::Var:
return Var == Other.Var;
case TypeKind::Con:
return Con == Other.Con;
case TypeKind::Present:
return Present == Other.Present;
case TypeKind::Absent:
return Absent == Other.Absent;
case TypeKind::Arrow:
return Arrow == Other.Arrow;
case TypeKind::Field:
return Field == Other.Field;
case TypeKind::Nil:
return Nil == Other.Nil;
case TypeKind::Tuple:
return Tuple == Other.Tuple;
case TypeKind::TupleIndex:
return TupleIndex == Other.TupleIndex;
case TypeKind::App:
return App == Other.App;
}
}
void Type::visitEachChild(std::function<void(Type*)> Proc) {
switch (Kind) {
case TypeKind::Var:
case TypeKind::Absent:
case TypeKind::Nil:
case TypeKind::Con:
break;
case TypeKind::Arrow:
{
Proc(Arrow.ParamType);
Proc(Arrow.ReturnType);
break;
}
case TypeKind::Tuple:
{
for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) {
Proc(Tuple.ElementTypes[I]);
}
break;
}
case TypeKind::App:
{
Proc(App.Op);
Proc(App.Arg);
break;
}
case TypeKind::Field:
{
Proc(Field.Ty);
Proc(Field.RestTy);
break;
}
case TypeKind::Present:
{
Proc(Present.Ty);
break;
}
case TypeKind::TupleIndex:
{
Proc(TupleIndex.Ty);
break;
}
}
@ -81,49 +166,49 @@ namespace bolt {
return Ty2;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(Ty2);
auto Arrow = Ty2->asArrow();
bool Changed = false;
Type* NewParamType = Arrow->ParamType->rewrite(Fn);
if (NewParamType != Arrow->ParamType) {
Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive);
if (NewParamType != Arrow.ParamType) {
Changed = true;
}
auto NewRetTy = Arrow->ReturnType->rewrite(Fn);
if (NewRetTy != Arrow->ReturnType) {
auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive);
if (NewRetTy != Arrow.ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamType, NewRetTy) : Ty2;
return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2;
}
case TypeKind::Con:
return Ty2;
case TypeKind::App:
{
auto App = static_cast<TApp*>(Ty2);
auto NewOp = App->Op->rewrite(Fn);
auto NewArg = App->Arg->rewrite(Fn);
if (NewOp == App->Op && NewArg == App->Arg) {
return App;
auto App = Ty2->asApp();
auto NewOp = App.Op->rewrite(Fn, Recursive);
auto NewArg = App.Arg->rewrite(Fn, Recursive);
if (NewOp == App.Op && NewArg == App.Arg) {
return Ty2;
}
return new TApp(NewOp, NewArg);
return new Type(TApp(NewOp, NewArg));
}
case TypeKind::TupleIndex:
{
auto Tuple = static_cast<TTupleIndex*>(Ty2);
auto NewTy = Tuple->Ty->rewrite(Fn);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
auto Tuple = Ty2->asTupleIndex();
auto NewTy = Tuple.Ty->rewrite(Fn, Recursive);
return NewTy != Tuple.Ty ? new Type(TTupleIndex(NewTy, Tuple.I)) : Ty2;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(Ty2);
auto Tuple = Ty2->asTuple();
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = Ty->rewrite(Fn);
for (auto Ty: Tuple.ElementTypes) {
auto NewElementType = Ty->rewrite(Fn, Recursive);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : Ty2;
return Changed ? new Type(TTuple(NewElementTypes)) : Ty2;
}
case TypeKind::Nil:
return Ty2;
@ -131,272 +216,77 @@ namespace bolt {
return Ty2;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(Ty2);
auto Field = Ty2->asField();
bool Changed = false;
auto NewTy = Field->Ty->rewrite(Fn);
if (NewTy != Field->Ty) {
auto NewTy = Field.Ty->rewrite(Fn, Recursive);
if (NewTy != Field.Ty) {
Changed = true;
}
auto NewRestTy = Field->RestTy->rewrite(Fn);
if (NewRestTy != Field->RestTy) {
auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive);
if (NewRestTy != Field.RestTy) {
Changed = true;
}
return Changed ? new TField(Field->Name, NewTy, NewRestTy) : Ty2;
return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2;
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(Ty2);
auto NewTy = Present->Ty->rewrite(Fn);
if (NewTy == Present->Ty) {
auto Present = Ty2->asPresent();
auto NewTy = Present.Ty->rewrite(Fn, Recursive);
if (NewTy == Present.Ty) {
return Ty2;
}
return new TPresent(NewTy);
return new Type(TPresent(NewTy));
}
}
}
void Type::addTypeVars(TVSet& TVs) {
switch (Kind) {
case TypeKind::Var:
TVs.emplace(static_cast<TVar*>(this));
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
Arrow->ParamType->addTypeVars(TVs);
Arrow->ReturnType->addTypeVars(TVs);
break;
}
case TypeKind::Con:
break;
case TypeKind::App:
{
auto App = static_cast<TApp*>(this);
App->Op->addTypeVars(TVs);
App->Arg->addTypeVars(TVs);
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
Index->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
Ty->addTypeVars(TVs);
}
break;
}
case TypeKind::Nil:
break;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
Field->Ty->addTypeVars(TVs);
Field->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
Present->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Absent:
break;
}
}
bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) {
case TypeKind::Var:
return static_cast<TVar*>(this)->Id == TV->Id;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
return Arrow->ParamType->hasTypeVar(TV) || Arrow->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
return false;
case TypeKind::App:
{
auto App = static_cast<TApp*>(this);
return App->Op->hasTypeVar(TV) || App->Arg->hasTypeVar(TV);
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(this);
return Index->Ty->hasTypeVar(TV);
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Tuple->ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::Nil:
return false;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
return Field->Ty->hasTypeVar(TV) || Field->RestTy->hasTypeVar(TV);
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
return Present->Ty->hasTypeVar(TV);
}
case TypeKind::Absent:
return false;
}
}
Type* Type::solve() {
return rewrite([](auto Ty) {
if (Ty->getKind() == TypeKind::Var) {
return static_cast<TVar*>(Ty)->find();
}
return Ty;
});
}
Type* Type::substitute(const TVSub &Sub) {
return rewrite([&](auto Ty) {
if (isa<TVar>(Ty)) {
auto TV = static_cast<TVar*>(Ty);
auto Match = Sub.find(TV);
if (Ty->isVar()) {
auto Match = Sub.find(Ty);
return Match != Sub.end() ? Match->second->substitute(Sub) : Ty;
}
return Ty;
});
}, false);
}
Type* Type::resolve(const TypeIndex& Index) const noexcept {
switch (Index.Kind) {
case TypeIndexKind::PresentType:
return cast<TPresent>(this)->Ty;
return this->asPresent().Ty;
case TypeIndexKind::AppOpType:
return cast<TApp>(this)->Op;
return this->asApp().Op;
case TypeIndexKind::AppArgType:
return cast<TApp>(this)->Arg;
return this->asApp().Arg;
case TypeIndexKind::TupleIndexType:
return cast<TTupleIndex>(this)->Ty;
return this->asTupleIndex().Ty;
case TypeIndexKind::TupleElement:
return cast<TTuple>(this)->ElementTypes[Index.I];
return this->asTuple().ElementTypes[Index.I];
case TypeIndexKind::ArrowParamType:
return cast<TArrow>(this)->ParamType;
return this->asArrow().ParamType;
case TypeIndexKind::ArrowReturnType:
return cast<TArrow>(this)->ReturnType;
return this->asArrow().ReturnType;
case TypeIndexKind::FieldType:
return cast<TField>(this)->Ty;
return this->asField().Ty;
case TypeIndexKind::FieldRestType:
return cast<TField>(this)->RestTy;
return this->asField().RestTy;
case TypeIndexKind::End:
ZEN_UNREACHABLE
}
ZEN_UNREACHABLE
}
bool Type::operator==(const Type& Other) const noexcept {
switch (Kind) {
case TypeKind::Var:
if (Other.Kind != TypeKind::Var) {
return false;
}
return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
case TypeKind::Tuple:
{
if (Other.Kind != TypeKind::Tuple) {
return false;
}
auto A = static_cast<const TTuple&>(*this);
auto B = static_cast<const TTuple&>(Other);
if (A.ElementTypes.size() != B.ElementTypes.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
case TypeKind::TupleIndex:
{
if (Other.Kind != TypeKind::TupleIndex) {
return false;
}
auto A = static_cast<const TTupleIndex&>(*this);
auto B = static_cast<const TTupleIndex&>(Other);
return A.I == B.I && *A.Ty == *B.Ty;
}
case TypeKind::Con:
{
if (Other.Kind != TypeKind::Con) {
return false;
}
auto A = static_cast<const TCon&>(*this);
auto B = static_cast<const TCon&>(Other);
if (A.Id != B.Id) {
return false;
}
return true;
}
case TypeKind::App:
{
if (Other.Kind != TypeKind::App) {
return false;
}
auto A = static_cast<const TApp&>(*this);
auto B = static_cast<const TApp&>(Other);
return *A.Op == *B.Op && *A.Arg == *B.Arg;
}
case TypeKind::Arrow:
{
if (Other.Kind != TypeKind::Arrow) {
return false;
}
auto A = static_cast<const TArrow&>(*this);
auto B = static_cast<const TArrow&>(Other);
return *A.ParamType == *B.ParamType && *A.ReturnType == *B.ReturnType;
}
case TypeKind::Absent:
if (Other.Kind != TypeKind::Absent) {
return false;
}
return true;
case TypeKind::Nil:
if (Other.Kind != TypeKind::Nil) {
return false;
}
return true;
case TypeKind::Present:
{
if (Other.Kind != TypeKind::Present) {
return false;
}
auto A = static_cast<const TPresent&>(*this);
auto B = static_cast<const TPresent&>(Other);
return *A.Ty == *B.Ty;
}
case TypeKind::Field:
{
if (Other.Kind != TypeKind::Field) {
return false;
}
auto A = static_cast<const TField&>(*this);
auto B = static_cast<const TField&>(Other);
return A.Name == B.Name && *A.Ty == *B.Ty && *A.RestTy == *B.RestTy;
}
TVSet Type::getTypeVars() {
TVSet Out;
std::function<void(Type*)> visit = [&](Type* Ty) {
if (Ty->isVar()) {
Out.emplace(Ty);
return;
}
Ty->visitEachChild(visit);
};
visit(this);
return Out;
}
TypeIterator Type::begin() {
@ -407,14 +297,13 @@ namespace bolt {
return TypeIterator { this, getEndIndex() };
}
TypeIndex Type::getStartIndex() {
TypeIndex Type::getStartIndex() const {
switch (Kind) {
case TypeKind::Arrow:
return TypeIndex::forArrowParamType();
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
if (Tuple->ElementTypes.empty()) {
if (asTuple().ElementTypes.empty()) {
return TypeIndex(TypeIndexKind::End);
}
return TypeIndex::forTupleElement(0);
@ -426,29 +315,38 @@ namespace bolt {
}
}
TypeIndex Type::getEndIndex() {
TypeIndex Type::getEndIndex() const {
return TypeIndex(TypeIndexKind::End);
}
inline Type* TVar::find() {
TVar* Curr = this;
for (;;) {
auto Keep = Curr->Parent;
if (Keep->getKind() != TypeKind::Var || Keep == Curr) {
return Keep;
}
auto TV = static_cast<TVar*>(Keep);
Curr->Parent = TV->Parent;
Curr = TV;
bool Type::hasTypeVar(Type* TV) const {
switch (Kind) {
case TypeKind::Var:
return Var.Id == TV->asVar().Id;
case TypeKind::Con:
case TypeKind::Absent:
case TypeKind::Nil:
return false;
case TypeKind::App:
return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV);
case TypeKind::Tuple:
for (auto Ty: Tuple.ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
void TVar::set(Type* Ty) {
auto Root = find();
// It is not possible to set a solution twice.
ZEN_ASSERT(Root->getKind() == TypeKind::Var);
static_cast<TVar*>(Root)->Parent = Ty;
return false;
case TypeKind::TupleIndex:
return TupleIndex.Ty->hasTypeVar(TV);
case TypeKind::Field:
return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV);
case TypeKind::Arrow:
return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV);
case TypeKind::Present:
return Present.Ty->hasTypeVar(TV);
}
}
}