Split up Checker.hpp and make room for better type mismatch errors
This commit is contained in:
parent
508ef40bdf
commit
302823ac9b
7 changed files with 821 additions and 428 deletions
|
@ -20,6 +20,7 @@ add_library(
|
|||
src/Diagnostics.cc
|
||||
src/Scanner.cc
|
||||
src/Parser.cc
|
||||
src/Types.cc
|
||||
src/Checker.cc
|
||||
src/IPRGraph.cc
|
||||
)
|
||||
|
|
|
@ -6,170 +6,15 @@
|
|||
#include "bolt/ByteString.hpp"
|
||||
#include "bolt/Common.hpp"
|
||||
#include "bolt/CST.hpp"
|
||||
#include "bolt/Diagnostics.hpp"
|
||||
#include "bolt/Type.hpp"
|
||||
|
||||
#include <istream>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class DiagnosticEngine;
|
||||
class Node;
|
||||
|
||||
class Type;
|
||||
class TVar;
|
||||
|
||||
using TVSub = std::unordered_map<TVar*, Type*>;
|
||||
using TVSet = std::unordered_set<TVar*>;
|
||||
|
||||
using TypeclassContext = std::unordered_set<TypeclassId>;
|
||||
|
||||
enum class TypeKind : unsigned char {
|
||||
Var,
|
||||
Con,
|
||||
Arrow,
|
||||
Tuple,
|
||||
TupleIndex,
|
||||
};
|
||||
|
||||
class Type {
|
||||
|
||||
const TypeKind Kind;
|
||||
|
||||
protected:
|
||||
|
||||
inline Type(TypeKind Kind):
|
||||
Kind(Kind) {}
|
||||
|
||||
public:
|
||||
|
||||
bool hasTypeVar(const TVar* TV);
|
||||
|
||||
void addTypeVars(TVSet& TVs);
|
||||
|
||||
inline TVSet getTypeVars() {
|
||||
TVSet Out;
|
||||
addTypeVars(Out);
|
||||
return Out;
|
||||
}
|
||||
|
||||
Type* substitute(const TVSub& Sub);
|
||||
|
||||
inline TypeKind getKind() const noexcept {
|
||||
return Kind;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TCon : public Type {
|
||||
public:
|
||||
|
||||
const size_t Id;
|
||||
std::vector<Type*> Args;
|
||||
ByteString DisplayName;
|
||||
|
||||
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
|
||||
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Con;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
enum class VarKind {
|
||||
Rigid,
|
||||
Unification,
|
||||
};
|
||||
|
||||
class TVar : public Type {
|
||||
public:
|
||||
|
||||
const size_t Id;
|
||||
VarKind VK;
|
||||
|
||||
TypeclassContext Contexts;
|
||||
|
||||
inline TVar(size_t Id, VarKind VK):
|
||||
Type(TypeKind::Var), Id(Id), VK(VK) {}
|
||||
|
||||
inline VarKind getVarKind() const noexcept {
|
||||
return VK;
|
||||
}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Var;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TVarRigid : public TVar {
|
||||
public:
|
||||
|
||||
ByteString Name;
|
||||
|
||||
inline TVarRigid(size_t Id, ByteString Name):
|
||||
TVar(Id, VarKind::Rigid), Name(Name) {}
|
||||
|
||||
};
|
||||
|
||||
class TArrow : public Type {
|
||||
public:
|
||||
|
||||
std::vector<Type*> ParamTypes;
|
||||
Type* ReturnType;
|
||||
|
||||
inline TArrow(
|
||||
std::vector<Type*> ParamTypes,
|
||||
Type* ReturnType
|
||||
): Type(TypeKind::Arrow),
|
||||
ParamTypes(ParamTypes),
|
||||
ReturnType(ReturnType) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Arrow;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TTuple : public Type {
|
||||
public:
|
||||
|
||||
std::vector<Type*> ElementTypes;
|
||||
|
||||
inline TTuple(std::vector<Type*> ElementTypes):
|
||||
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Tuple;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TTupleIndex : public Type {
|
||||
public:
|
||||
|
||||
Type* Ty;
|
||||
std::size_t I;
|
||||
|
||||
inline TTupleIndex(Type* Ty, std::size_t I):
|
||||
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::TupleIndex;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// template<typename T>
|
||||
// struct DerefHash {
|
||||
// std::size_t operator()(const T& Value) const noexcept {
|
||||
// return std::hash<decltype(*Value)>{}(*Value);
|
||||
// }
|
||||
// };
|
||||
|
||||
class Constraint;
|
||||
|
||||
|
@ -354,6 +199,8 @@ namespace bolt {
|
|||
|
||||
std::unordered_map<Node*, InferContext*> CallGraph;
|
||||
|
||||
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
|
||||
|
||||
Type* BoolType;
|
||||
Type* IntType;
|
||||
Type* StringType;
|
||||
|
@ -412,20 +259,43 @@ namespace bolt {
|
|||
|
||||
Type* instantiate(Scheme* S, Node* Source);
|
||||
|
||||
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
|
||||
std::vector<TypeclassContext> findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source);
|
||||
void propagateClasses(TypeclassContext& Classes, Type* Ty, Node* Source);
|
||||
void propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source);
|
||||
|
||||
void checkTypeclassSigs(Node* N);
|
||||
std::vector<TypeclassContext> findInstanceContext(TCon* Ty, TypeclassId& Class);
|
||||
void propagateClasses(TypeclassContext& Classes, Type* Ty);
|
||||
void propagateClassTycon(TypeclassId& Class, TCon* Ty);
|
||||
|
||||
Type* simplify(Type* Ty);
|
||||
void join(TVar* A, Type* B, Node* Source);
|
||||
bool unify(Type* A, Type* B, Node* Source);
|
||||
|
||||
/**
|
||||
* Assign a type to a unification variable.
|
||||
*
|
||||
* If there are class constraints, those are propagated.
|
||||
*
|
||||
* If this type variable is solved during inference, it will be removed from
|
||||
* the inference context.
|
||||
*
|
||||
* Other side effects may occur.
|
||||
*/
|
||||
void join(TVar* A, Type* B);
|
||||
|
||||
Type* OrigLeft;
|
||||
Type* OrigRight;
|
||||
TypePath LeftPath;
|
||||
TypePath RightPath;
|
||||
Node* Source;
|
||||
|
||||
bool unify(Type* A, Type* B);
|
||||
|
||||
void unifyError();
|
||||
void solveCEqual(CEqual* C);
|
||||
|
||||
void solve(Constraint* Constraint, TVSub& Solution);
|
||||
|
||||
/**
|
||||
* Verifies that type class signatures on type asserts in let-declarations
|
||||
* correctly declare the right type classes.
|
||||
*/
|
||||
void checkTypeclassSigs(Node* N);
|
||||
|
||||
public:
|
||||
|
||||
Checker(const LanguageConfig& Config, DiagnosticEngine& DE);
|
||||
|
|
|
@ -9,27 +9,10 @@
|
|||
#include "bolt/ByteString.hpp"
|
||||
#include "bolt/String.hpp"
|
||||
#include "bolt/CST.hpp"
|
||||
#include "bolt/Type.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Type;
|
||||
class TCon;
|
||||
class TVar;
|
||||
class TTuple;
|
||||
|
||||
using TypeclassId = ByteString;
|
||||
|
||||
struct TypeclassSignature {
|
||||
|
||||
using TypeclassId = ByteString;
|
||||
TypeclassId Id;
|
||||
std::vector<TVar*> Params;
|
||||
|
||||
bool operator<(const TypeclassSignature& Other) const;
|
||||
bool operator==(const TypeclassSignature& Other) const;
|
||||
|
||||
};
|
||||
|
||||
enum class DiagnosticKind : unsigned char {
|
||||
UnexpectedToken,
|
||||
UnexpectedString,
|
||||
|
@ -95,13 +78,15 @@ namespace bolt {
|
|||
|
||||
class UnificationErrorDiagnostic : public Diagnostic {
|
||||
public:
|
||||
|
||||
|
||||
Type* Left;
|
||||
Type* Right;
|
||||
TypePath LeftPath;
|
||||
TypePath RightPath;
|
||||
Node* Source;
|
||||
|
||||
inline UnificationErrorDiagnostic(Type* Left, Type* Right, Node* Source):
|
||||
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), Source(Source) {}
|
||||
inline UnificationErrorDiagnostic(Type* Left, Type* Right, TypePath LeftPath, TypePath RightPath, Node* Source):
|
||||
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {}
|
||||
|
||||
};
|
||||
|
||||
|
|
282
include/bolt/Type.hpp
Normal file
282
include/bolt/Type.hpp
Normal file
|
@ -0,0 +1,282 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "bolt/ByteString.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Type;
|
||||
class TVar;
|
||||
|
||||
using TypeclassId = ByteString;
|
||||
|
||||
using TypeclassContext = std::unordered_set<TypeclassId>;
|
||||
|
||||
struct TypeclassSignature {
|
||||
|
||||
using TypeclassId = ByteString;
|
||||
TypeclassId Id;
|
||||
std::vector<TVar*> Params;
|
||||
|
||||
bool operator<(const TypeclassSignature& Other) const;
|
||||
bool operator==(const TypeclassSignature& Other) const;
|
||||
|
||||
};
|
||||
|
||||
enum class TypeIndexKind {
|
||||
ArrowParamType,
|
||||
ArrowReturnType,
|
||||
ConArg,
|
||||
TupleElement,
|
||||
End,
|
||||
};
|
||||
|
||||
class TypeIndex {
|
||||
protected:
|
||||
|
||||
friend class Type;
|
||||
friend class TypeIterator;
|
||||
|
||||
TypeIndexKind Kind;
|
||||
|
||||
union {
|
||||
std::size_t I;
|
||||
};
|
||||
|
||||
TypeIndex(TypeIndexKind Kind):
|
||||
Kind(Kind) {}
|
||||
|
||||
TypeIndex(TypeIndexKind Kind, std::size_t I):
|
||||
Kind(Kind), I(I) {}
|
||||
|
||||
public:
|
||||
|
||||
bool operator==(const TypeIndex& Other) const noexcept;
|
||||
|
||||
void advance(const Type* Ty);
|
||||
|
||||
static TypeIndex forArrowReturnType() {
|
||||
return { TypeIndexKind::ArrowReturnType };
|
||||
}
|
||||
|
||||
static TypeIndex forArrowParamType(std::size_t I) {
|
||||
return { TypeIndexKind::ArrowParamType, I };
|
||||
}
|
||||
|
||||
static TypeIndex forConArg(std::size_t I) {
|
||||
return { TypeIndexKind::ConArg, I };
|
||||
}
|
||||
|
||||
static TypeIndex forTupleElement(std::size_t I) {
|
||||
return { TypeIndexKind::TupleElement, I };
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TypeIterator {
|
||||
|
||||
friend class Type;
|
||||
|
||||
Type* Ty;
|
||||
TypeIndex Index;
|
||||
|
||||
TypeIterator(Type* Ty, TypeIndex Index):
|
||||
Ty(Ty), Index(Index) {}
|
||||
|
||||
public:
|
||||
|
||||
TypeIterator& operator++() noexcept {
|
||||
Index.advance(Ty);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const TypeIterator& Other) const noexcept {
|
||||
return Ty == Other.Ty && Index == Other.Index;
|
||||
}
|
||||
|
||||
Type* operator*() {
|
||||
return Ty;
|
||||
}
|
||||
|
||||
TypeIndex getIndex() const noexcept {
|
||||
return Index;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
using TypePath = std::vector<TypeIndex>;
|
||||
|
||||
using TVSub = std::unordered_map<TVar*, Type*>;
|
||||
using TVSet = std::unordered_set<TVar*>;
|
||||
|
||||
enum class TypeKind : unsigned char {
|
||||
Var,
|
||||
Con,
|
||||
Arrow,
|
||||
Tuple,
|
||||
TupleIndex,
|
||||
};
|
||||
|
||||
class Type {
|
||||
|
||||
const TypeKind Kind;
|
||||
|
||||
protected:
|
||||
|
||||
inline Type(TypeKind Kind):
|
||||
Kind(Kind) {}
|
||||
|
||||
public:
|
||||
|
||||
inline TypeKind getKind() const noexcept {
|
||||
return Kind;
|
||||
}
|
||||
|
||||
bool hasTypeVar(const TVar* TV);
|
||||
|
||||
void addTypeVars(TVSet& TVs);
|
||||
|
||||
inline TVSet getTypeVars() {
|
||||
TVSet Out;
|
||||
addTypeVars(Out);
|
||||
return Out;
|
||||
}
|
||||
|
||||
Type* substitute(const TVSub& Sub);
|
||||
|
||||
TypeIterator begin();
|
||||
TypeIterator end();
|
||||
|
||||
TypeIndex getStartIndex();
|
||||
TypeIndex getEndIndex();
|
||||
|
||||
Type* resolve(const TypeIndex& Index) const noexcept;
|
||||
|
||||
Type* resolve(const TypePath& Path) noexcept {
|
||||
Type* Ty = this;
|
||||
for (auto El: Path) {
|
||||
Ty = Ty->resolve(El);
|
||||
}
|
||||
return Ty;
|
||||
}
|
||||
|
||||
bool operator==(const Type& Other) const noexcept;
|
||||
|
||||
bool operator!=(const Type& Other) const noexcept {
|
||||
return !(*this == Other);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TCon : public Type {
|
||||
public:
|
||||
|
||||
const size_t Id;
|
||||
std::vector<Type*> Args;
|
||||
ByteString DisplayName;
|
||||
|
||||
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
|
||||
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Con;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
enum class VarKind {
|
||||
Rigid,
|
||||
Unification,
|
||||
};
|
||||
|
||||
class TVar : public Type {
|
||||
public:
|
||||
|
||||
const size_t Id;
|
||||
VarKind VK;
|
||||
|
||||
TypeclassContext Contexts;
|
||||
|
||||
inline TVar(size_t Id, VarKind VK):
|
||||
Type(TypeKind::Var), Id(Id), VK(VK) {}
|
||||
|
||||
inline VarKind getVarKind() const noexcept {
|
||||
return VK;
|
||||
}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Var;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TVarRigid : public TVar {
|
||||
public:
|
||||
|
||||
ByteString Name;
|
||||
|
||||
inline TVarRigid(size_t Id, ByteString Name):
|
||||
TVar(Id, VarKind::Rigid), Name(Name) {}
|
||||
|
||||
};
|
||||
|
||||
class TArrow : public Type {
|
||||
public:
|
||||
|
||||
std::vector<Type*> ParamTypes;
|
||||
Type* ReturnType;
|
||||
|
||||
inline TArrow(
|
||||
std::vector<Type*> ParamTypes,
|
||||
Type* ReturnType
|
||||
): Type(TypeKind::Arrow),
|
||||
ParamTypes(ParamTypes),
|
||||
ReturnType(ReturnType) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Arrow;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TTuple : public Type {
|
||||
public:
|
||||
|
||||
std::vector<Type*> ElementTypes;
|
||||
|
||||
inline TTuple(std::vector<Type*> ElementTypes):
|
||||
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Tuple;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class TTupleIndex : public Type {
|
||||
public:
|
||||
|
||||
Type* Ty;
|
||||
std::size_t I;
|
||||
|
||||
inline TTupleIndex(Type* Ty, std::size_t I):
|
||||
Type(TypeKind::TupleIndex), Ty(Ty), I(I) {}
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::TupleIndex;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// template<typename T>
|
||||
// struct DerefHash {
|
||||
// std::size_t operator()(const T& Value) const noexcept {
|
||||
// return std::hash<decltype(*Value)>{}(*Value);
|
||||
// }
|
||||
// };
|
||||
|
||||
}
|
351
src/Checker.cc
351
src/Checker.cc
|
@ -25,165 +25,6 @@ namespace bolt {
|
|||
|
||||
std::string describe(const Type* Ty);
|
||||
|
||||
bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
|
||||
if (Id < Other.Id) {
|
||||
return true;
|
||||
}
|
||||
ZEN_ASSERT(Params.size() == 1);
|
||||
ZEN_ASSERT(Other.Params.size() == 1);
|
||||
return Params[0]->Id < Other.Params[0]->Id;
|
||||
}
|
||||
|
||||
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
|
||||
ZEN_ASSERT(Params.size() == 1);
|
||||
ZEN_ASSERT(Other.Params.size() == 1);
|
||||
return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
|
||||
}
|
||||
|
||||
void Type::addTypeVars(TVSet& TVs) {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
TVs.emplace(static_cast<TVar*>(this));
|
||||
break;
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
Arrow->ReturnType->addTypeVars(TVs);
|
||||
break;
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
for (auto Ty: Con->Args) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(this);
|
||||
Index->Ty->addTypeVars(TVs);
|
||||
break;
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Type::hasTypeVar(const TVar* TV) {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
return static_cast<TVar*>(this)->Id == TV->Id;
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return Arrow->ReturnType->hasTypeVar(TV);
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
for (auto Ty: Con->Args) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(this);
|
||||
return Index->Ty->hasTypeVar(TV);
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Type* Type::substitute(const TVSub &Sub) {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
{
|
||||
auto TV = static_cast<TVar*>(this);
|
||||
auto Match = Sub.find(TV);
|
||||
return Match != Sub.end() ? Match->second->substitute(Sub) : this;
|
||||
}
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewParamTypes;
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
auto NewParamType = Ty->substitute(Sub);
|
||||
if (NewParamType != Ty) {
|
||||
Changed = true;
|
||||
}
|
||||
NewParamTypes.push_back(NewParamType);
|
||||
}
|
||||
auto NewRetTy = Arrow->ReturnType->substitute(Sub) ;
|
||||
if (NewRetTy != Arrow->ReturnType) {
|
||||
Changed = true;
|
||||
}
|
||||
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewArgs;
|
||||
for (auto Arg: Con->Args) {
|
||||
auto NewArg = Arg->substitute(Sub);
|
||||
if (NewArg != Arg) {
|
||||
Changed = true;
|
||||
}
|
||||
NewArgs.push_back(NewArg);
|
||||
}
|
||||
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Tuple = static_cast<TTupleIndex*>(this);
|
||||
auto NewTy = Tuple->Ty->substitute(Sub);
|
||||
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewElementTypes;
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
auto NewElementType = Ty->substitute(Sub);
|
||||
if (NewElementType != Ty) {
|
||||
Changed = true;
|
||||
}
|
||||
NewElementTypes.push_back(NewElementType);
|
||||
}
|
||||
return Changed ? new TTuple(NewElementTypes) : this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Constraint* Constraint::substitute(const TVSub &Sub) {
|
||||
switch (Kind) {
|
||||
case ConstraintKind::Class:
|
||||
|
@ -1141,7 +982,7 @@ namespace bolt {
|
|||
ZEN_UNREACHABLE
|
||||
}
|
||||
|
||||
std::vector<TypeclassContext> Checker::findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source) {
|
||||
std::vector<TypeclassContext> Checker::findInstanceContext(TCon* Ty, TypeclassId& Class) {
|
||||
auto Match = InstanceMap.find(Class);
|
||||
std::vector<TypeclassContext> S;
|
||||
if (Match != InstanceMap.end()) {
|
||||
|
@ -1164,7 +1005,7 @@ namespace bolt {
|
|||
return S;
|
||||
}
|
||||
|
||||
void Checker::propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty, Node* Source) {
|
||||
void Checker::propagateClasses(std::unordered_set<TypeclassId>& Classes, Type* Ty) {
|
||||
if (llvm::isa<TVar>(Ty)) {
|
||||
auto TV = llvm::cast<TVar>(Ty);
|
||||
for (auto Class: Classes) {
|
||||
|
@ -1172,61 +1013,29 @@ namespace bolt {
|
|||
}
|
||||
} else if (llvm::isa<TCon>(Ty)) {
|
||||
for (auto Class: Classes) {
|
||||
propagateClassTycon(Class, llvm::cast<TCon>(Ty), Source);
|
||||
propagateClassTycon(Class, llvm::cast<TCon>(Ty));
|
||||
}
|
||||
} else if (!Classes.empty()) {
|
||||
DE.add<InvalidTypeToTypeclassDiagnostic>(Ty);
|
||||
}
|
||||
};
|
||||
|
||||
void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source) {
|
||||
auto S = findInstanceContext(Ty, Class, Source);
|
||||
void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty) {
|
||||
auto S = findInstanceContext(Ty, Class);
|
||||
for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) {
|
||||
propagateClasses(Classes, Arg, Source);
|
||||
propagateClasses(Classes, Arg);
|
||||
}
|
||||
};
|
||||
|
||||
class ArrowCursor {
|
||||
|
||||
std::stack<std::tuple<TArrow*, std::size_t>> Path;
|
||||
|
||||
public:
|
||||
|
||||
ArrowCursor(TArrow* Arr) {
|
||||
Path.push({ Arr, 0 });
|
||||
}
|
||||
|
||||
Type* next() {
|
||||
while (!Path.empty()) {
|
||||
auto& [Arr, I] = Path.top();
|
||||
Type* Ty;
|
||||
if (I == -1) {
|
||||
Path.pop();
|
||||
continue;
|
||||
}
|
||||
if (I == Arr->ParamTypes.size()) {
|
||||
I = -1;
|
||||
Ty = Arr->ReturnType;
|
||||
} else {
|
||||
Ty = Arr->ParamTypes[I];
|
||||
I++;
|
||||
}
|
||||
if (llvm::isa<TArrow>(Ty)) {
|
||||
Path.push({ static_cast<TArrow*>(Ty), 0 });
|
||||
} else {
|
||||
return Ty;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
void Checker::solveCEqual(CEqual* C) {
|
||||
std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl;
|
||||
if (!unify(C->Left, C->Right, C->Source)) {
|
||||
DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source);
|
||||
}
|
||||
/* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */
|
||||
OrigLeft = C->Left;
|
||||
OrigRight = C->Right;
|
||||
Source = C->Source;
|
||||
unify(C->Left, C->Right);
|
||||
LeftPath = {};
|
||||
RightPath = {};
|
||||
/* DE.add<UnificationErrorDiagnostic>(simplify(C->Left), simplify(C->Right), C->Source); */
|
||||
}
|
||||
|
||||
Type* Checker::simplify(Type* Ty) {
|
||||
|
@ -1314,11 +1123,11 @@ namespace bolt {
|
|||
return Ty;
|
||||
}
|
||||
|
||||
void Checker::join(TVar* TV, Type* Ty, Node* Source) {
|
||||
void Checker::join(TVar* TV, Type* Ty) {
|
||||
|
||||
Solution[TV] = Ty;
|
||||
|
||||
propagateClasses(TV->Contexts, Ty, Source);
|
||||
propagateClasses(TV->Contexts, Ty);
|
||||
|
||||
// This is a very specific adjustment that is critical to the
|
||||
// well-functioning of the infer/unify algorithm. When addConstraint() is
|
||||
|
@ -1335,24 +1144,62 @@ namespace bolt {
|
|||
|
||||
}
|
||||
|
||||
bool Checker::unify(Type* A, Type* B, Node* Source) {
|
||||
void Checker::unifyError() {
|
||||
DE.add<UnificationErrorDiagnostic>(
|
||||
simplify(OrigLeft),
|
||||
simplify(OrigRight),
|
||||
LeftPath,
|
||||
RightPath,
|
||||
Source
|
||||
);
|
||||
}
|
||||
|
||||
auto find = [&](auto OrigTy) {
|
||||
auto Ty = OrigTy;
|
||||
if (llvm::isa<TVar>(Ty)) {
|
||||
auto TV = static_cast<TVar*>(Ty);
|
||||
do {
|
||||
auto Match = Solution.find(static_cast<TVar*>(Ty));
|
||||
if (Match == Solution.end()) {
|
||||
break;
|
||||
}
|
||||
Ty = Match->second;
|
||||
} while (Ty->getKind() == TypeKind::Var);
|
||||
// FIXME does this actually improove performance?
|
||||
Solution[TV] = Ty;
|
||||
class ArrowCursor {
|
||||
|
||||
std::stack<std::tuple<TArrow*, bool>> Stack;
|
||||
TypePath& Path;
|
||||
std::size_t I;
|
||||
|
||||
public:
|
||||
|
||||
ArrowCursor(TArrow* Arr, TypePath& Path):
|
||||
Path(Path) {
|
||||
Stack.push({ Arr, true });
|
||||
Path.push_back(Arr->getStartIndex());
|
||||
}
|
||||
return Ty;
|
||||
};
|
||||
|
||||
Type* next() {
|
||||
while (!Stack.empty()) {
|
||||
auto& [Arrow, First] = Stack.top();
|
||||
auto& Index = Path.back();
|
||||
if (!First) {
|
||||
Index.advance(Arrow);
|
||||
} else {
|
||||
First = false;
|
||||
}
|
||||
Type* Ty;
|
||||
if (Index == Arrow->getEndIndex()) {
|
||||
Path.pop_back();
|
||||
Stack.pop();
|
||||
continue;
|
||||
}
|
||||
Ty = Arrow->resolve(Index);
|
||||
if (llvm::isa<TArrow>(Ty)) {
|
||||
auto NewIndex = Arrow->getStartIndex();
|
||||
Stack.push({ static_cast<TArrow*>(Ty), true });
|
||||
Path.push_back(NewIndex);
|
||||
} else {
|
||||
return Ty;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
bool Checker::unify(Type* A, Type* B) {
|
||||
|
||||
A = simplify(A);
|
||||
B = simplify(B);
|
||||
|
@ -1362,6 +1209,7 @@ namespace bolt {
|
|||
auto Var2 = static_cast<TVar*>(B);
|
||||
if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) {
|
||||
if (Var1->Id != Var2->Id) {
|
||||
unifyError();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -1373,38 +1221,47 @@ namespace bolt {
|
|||
From = Var2;
|
||||
} else {
|
||||
// Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification
|
||||
// Either way, Var1 is a good candidate for being unified away
|
||||
// Either way, Var1, being Unification, is a good candidate for being unified away
|
||||
To = Var2;
|
||||
From = Var1;
|
||||
}
|
||||
join(From, To, Source);
|
||||
propagateClasses(From->Contexts, To, Source);
|
||||
join(From, To);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (llvm::isa<TVar>(A)) {
|
||||
|
||||
auto TV = static_cast<TVar*>(A);
|
||||
|
||||
// Rigid type variables can never unify with antything else than what we
|
||||
// have already handled in the previous if-statement, so issue an error.
|
||||
if (TV->getVarKind() == VarKind::Rigid) {
|
||||
unifyError();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Occurs check
|
||||
if (B->hasTypeVar(TV)) {
|
||||
// NOTE Just like GHC, we just display an error message indicating that
|
||||
// A cannot match B, e.g. a cannot match [a]. It looks much better
|
||||
// than obsure references to an occurs check
|
||||
unifyError();
|
||||
return false;
|
||||
}
|
||||
join(TV, B, Source);
|
||||
|
||||
join(TV, B);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (llvm::isa<TVar>(B)) {
|
||||
return unify(B, A, Source);
|
||||
return unify(B, A);
|
||||
}
|
||||
|
||||
if (llvm::isa<TArrow>(A) && llvm::isa<TArrow>(B)) {
|
||||
auto C1 = ArrowCursor(static_cast<TArrow*>(A));
|
||||
auto C2 = ArrowCursor(static_cast<TArrow*>(B));
|
||||
auto C1 = ArrowCursor(static_cast<TArrow*>(A), LeftPath);
|
||||
auto C2 = ArrowCursor(static_cast<TArrow*>(B), RightPath);
|
||||
bool Success = true;
|
||||
for (;;) {
|
||||
auto T1 = C1.next();
|
||||
auto T2 = C2.next();
|
||||
|
@ -1412,13 +1269,15 @@ namespace bolt {
|
|||
break;
|
||||
}
|
||||
if (T1 == nullptr || T2 == nullptr) {
|
||||
return false;
|
||||
unifyError();
|
||||
Success = false;
|
||||
break;
|
||||
}
|
||||
if (!unify(T1, T2, Source)) {
|
||||
return false;
|
||||
if (!unify(T1, T2)) {
|
||||
Success = false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return Success;
|
||||
/* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */
|
||||
/* return false; */
|
||||
/* } */
|
||||
|
@ -1434,26 +1293,31 @@ namespace bolt {
|
|||
if (llvm::isa<TArrow>(A)) {
|
||||
auto Arr = static_cast<TArrow*>(A);
|
||||
if (Arr->ParamTypes.empty()) {
|
||||
return unify(Arr->ReturnType, B, Source);
|
||||
return unify(Arr->ReturnType, B);
|
||||
}
|
||||
}
|
||||
|
||||
if (llvm::isa<TArrow>(B)) {
|
||||
return unify(B, A, Source);
|
||||
return unify(B, A);
|
||||
}
|
||||
|
||||
if (llvm::isa<TTuple>(A) && llvm::isa<TTuple>(B)) {
|
||||
auto Tuple1 = static_cast<TTuple*>(A);
|
||||
auto Tuple2 = static_cast<TTuple*>(B);
|
||||
if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) {
|
||||
unifyError();
|
||||
return false;
|
||||
}
|
||||
auto Count = Tuple1->ElementTypes.size();
|
||||
bool Success = true;
|
||||
for (size_t I = 0; I < Count; I++) {
|
||||
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], Source)) {
|
||||
LeftPath.push_back(TypeIndex::forTupleElement(I));
|
||||
RightPath.push_back(TypeIndex::forTupleElement(I));
|
||||
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.pop_back();
|
||||
RightPath.pop_back();
|
||||
}
|
||||
return Success;
|
||||
}
|
||||
|
@ -1468,18 +1332,25 @@ namespace bolt {
|
|||
auto Con1 = static_cast<TCon*>(A);
|
||||
auto Con2 = static_cast<TCon*>(B);
|
||||
if (Con1->Id != Con2->Id) {
|
||||
unifyError();
|
||||
return false;
|
||||
}
|
||||
ZEN_ASSERT(Con1->Args.size() == Con2->Args.size());
|
||||
auto Count = Con1->Args.size();
|
||||
bool Success = true;
|
||||
for (std::size_t I = 0; I < Count; I++) {
|
||||
if (!unify(Con1->Args[I], Con2->Args[I], Source)) {
|
||||
return false;
|
||||
LeftPath.push_back(TypeIndex::forConArg(I));
|
||||
RightPath.push_back(TypeIndex::forConArg(I));
|
||||
if (!unify(Con1->Args[I], Con2->Args[I])) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.pop_back();
|
||||
RightPath.pop_back();
|
||||
}
|
||||
return true;
|
||||
return Success;
|
||||
}
|
||||
|
||||
unifyError();
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -438,14 +438,29 @@ namespace bolt {
|
|||
setBold(true);
|
||||
Out << "error: ";
|
||||
resetStyles();
|
||||
Out << "the types " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET
|
||||
<< " and " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << " failed to match\n\n";
|
||||
auto Left = E.Left->resolve(E.LeftPath);
|
||||
auto Right = E.Right->resolve(E.RightPath);
|
||||
Out << "the types " << ANSI_FG_GREEN << describe(Left) << ANSI_RESET
|
||||
<< " and " << ANSI_FG_GREEN << describe(Right) << ANSI_RESET << " failed to match\n\n";
|
||||
if (E.Source) {
|
||||
auto Range = E.Source->getRange();
|
||||
//std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n";
|
||||
writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red);
|
||||
Out << "\n";
|
||||
}
|
||||
if (!E.LeftPath.empty()) {
|
||||
setForegroundColor(Color::Yellow);
|
||||
setBold(true);
|
||||
Out << " info: ";
|
||||
resetStyles();
|
||||
Out << "type " << ANSI_FG_GREEN << describe(Left) << ANSI_RESET << " occurs in the full type " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET << "\n\n";
|
||||
}
|
||||
if (!E.RightPath.empty()) {
|
||||
setForegroundColor(Color::Yellow);
|
||||
setBold(true);
|
||||
Out << " info: ";
|
||||
resetStyles();
|
||||
Out << "type " << ANSI_FG_GREEN << describe(Right) << ANSI_RESET << " occurs in the full type " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << "\n\n";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
369
src/Types.cc
Normal file
369
src/Types.cc
Normal file
|
@ -0,0 +1,369 @@
|
|||
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include "zen/config.hpp"
|
||||
#include "zen/range.hpp"
|
||||
|
||||
#include "bolt/Type.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
|
||||
if (Id < Other.Id) {
|
||||
return true;
|
||||
}
|
||||
ZEN_ASSERT(Params.size() == 1);
|
||||
ZEN_ASSERT(Other.Params.size() == 1);
|
||||
return Params[0]->Id < Other.Params[0]->Id;
|
||||
}
|
||||
|
||||
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
|
||||
ZEN_ASSERT(Params.size() == 1);
|
||||
ZEN_ASSERT(Other.Params.size() == 1);
|
||||
return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
|
||||
}
|
||||
|
||||
bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
|
||||
if (Kind != Other.Kind) {
|
||||
return false;
|
||||
}
|
||||
switch (Kind) {
|
||||
case TypeIndexKind::ConArg:
|
||||
case TypeIndexKind::ArrowParamType:
|
||||
case TypeIndexKind::TupleElement:
|
||||
return I == Other.I;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void TypeIndex::advance(const Type* Ty) {
|
||||
switch (Kind) {
|
||||
case TypeIndexKind::End:
|
||||
break;
|
||||
case TypeIndexKind::ArrowParamType:
|
||||
{
|
||||
auto Arrow = llvm::cast<TArrow>(Ty);
|
||||
if (I+1 < Arrow->ParamTypes.size()) {
|
||||
++I;
|
||||
} else {
|
||||
Kind = TypeIndexKind::ArrowReturnType;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TypeIndexKind::ArrowReturnType:
|
||||
Kind = TypeIndexKind::End;
|
||||
break;
|
||||
case TypeIndexKind::ConArg:
|
||||
{
|
||||
auto Con = llvm::cast<TCon>(Ty);
|
||||
if (I+1 < Con->Args.size()) {
|
||||
++I;
|
||||
} else {
|
||||
Kind = TypeIndexKind::End;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TypeIndexKind::TupleElement:
|
||||
{
|
||||
auto Tuple = llvm::cast<TTuple>(Ty);
|
||||
if (I+1 < Tuple->ElementTypes.size()) {
|
||||
++I;
|
||||
} else {
|
||||
Kind = TypeIndexKind::End;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Type::addTypeVars(TVSet& TVs) {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
TVs.emplace(static_cast<TVar*>(this));
|
||||
break;
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
Arrow->ReturnType->addTypeVars(TVs);
|
||||
break;
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
for (auto Ty: Con->Args) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(this);
|
||||
Index->Ty->addTypeVars(TVs);
|
||||
break;
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Type::hasTypeVar(const TVar* TV) {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
return static_cast<TVar*>(this)->Id == TV->Id;
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return Arrow->ReturnType->hasTypeVar(TV);
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
for (auto Ty: Con->Args) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Index = static_cast<TTupleIndex*>(this);
|
||||
return Index->Ty->hasTypeVar(TV);
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Type* Type::substitute(const TVSub &Sub) {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
{
|
||||
auto TV = static_cast<TVar*>(this);
|
||||
auto Match = Sub.find(TV);
|
||||
return Match != Sub.end() ? Match->second->substitute(Sub) : this;
|
||||
}
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewParamTypes;
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
auto NewParamType = Ty->substitute(Sub);
|
||||
if (NewParamType != Ty) {
|
||||
Changed = true;
|
||||
}
|
||||
NewParamTypes.push_back(NewParamType);
|
||||
}
|
||||
auto NewRetTy = Arrow->ReturnType->substitute(Sub) ;
|
||||
if (NewRetTy != Arrow->ReturnType) {
|
||||
Changed = true;
|
||||
}
|
||||
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewArgs;
|
||||
for (auto Arg: Con->Args) {
|
||||
auto NewArg = Arg->substitute(Sub);
|
||||
if (NewArg != Arg) {
|
||||
Changed = true;
|
||||
}
|
||||
NewArgs.push_back(NewArg);
|
||||
}
|
||||
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
auto Tuple = static_cast<TTupleIndex*>(this);
|
||||
auto NewTy = Tuple->Ty->substitute(Sub);
|
||||
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewElementTypes;
|
||||
for (auto Ty: Tuple->ElementTypes) {
|
||||
auto NewElementType = Ty->substitute(Sub);
|
||||
if (NewElementType != Ty) {
|
||||
Changed = true;
|
||||
}
|
||||
NewElementTypes.push_back(NewElementType);
|
||||
}
|
||||
return Changed ? new TTuple(NewElementTypes) : this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Type* Type::resolve(const TypeIndex& Index) const noexcept {
|
||||
switch (Index.Kind) {
|
||||
case TypeIndexKind::ConArg:
|
||||
return llvm::cast<TCon>(this)->Args[Index.I];
|
||||
case TypeIndexKind::TupleElement:
|
||||
return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
|
||||
case TypeIndexKind::ArrowParamType:
|
||||
return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
|
||||
case TypeIndexKind::ArrowReturnType:
|
||||
return llvm::cast<TArrow>(this)->ReturnType;
|
||||
case TypeIndexKind::End:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
|
||||
bool Type::operator==(const Type& Other) const noexcept {
|
||||
switch (Kind) {
|
||||
case TypeKind::Var:
|
||||
if (Other.Kind != TypeKind::Var) {
|
||||
return false;
|
||||
}
|
||||
return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
if (Other.Kind != TypeKind::Tuple) {
|
||||
return false;
|
||||
}
|
||||
auto A = static_cast<const TTuple&>(*this);
|
||||
auto B = static_cast<const TTuple&>(Other);
|
||||
if (A.ElementTypes.size() != B.ElementTypes.size()) {
|
||||
return false;
|
||||
}
|
||||
for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
|
||||
if (*T1 != *T2) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case TypeKind::TupleIndex:
|
||||
{
|
||||
if (Other.Kind != TypeKind::TupleIndex) {
|
||||
return false;
|
||||
}
|
||||
auto A = static_cast<const TTupleIndex&>(*this);
|
||||
auto B = static_cast<const TTupleIndex&>(Other);
|
||||
return A.I == B.I && *A.Ty == *B.Ty;
|
||||
}
|
||||
case TypeKind::Con:
|
||||
{
|
||||
if (Other.Kind != TypeKind::Con) {
|
||||
return false;
|
||||
}
|
||||
auto A = static_cast<const TCon&>(*this);
|
||||
auto B = static_cast<const TCon&>(Other);
|
||||
if (A.Id != B.Id) {
|
||||
return false;
|
||||
}
|
||||
if (A.Args.size() != B.Args.size()) {
|
||||
return false;
|
||||
}
|
||||
for (auto [T1, T2]: zen::zip(A.Args, B.Args)) {
|
||||
if (*T1 != *T2) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
// FIXME Do we really need to 'curry' this type?
|
||||
if (Other.Kind != TypeKind::Arrow) {
|
||||
return false;
|
||||
}
|
||||
auto A = static_cast<const TArrow&>(*this);
|
||||
auto B = static_cast<const TArrow&>(Other);
|
||||
/* ArrowCursor C1 { &A }; */
|
||||
/* ArrowCursor C2 { &B }; */
|
||||
/* for (;;) { */
|
||||
/* auto T1 = C1.next(); */
|
||||
/* auto T2 = C2.next(); */
|
||||
/* if (T1 == nullptr && T2 == nullptr) { */
|
||||
/* break; */
|
||||
/* } */
|
||||
/* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */
|
||||
/* return false; */
|
||||
/* } */
|
||||
/* } */
|
||||
if (A.ParamTypes.size() != B.ParamTypes.size()) {
|
||||
return false;
|
||||
}
|
||||
for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) {
|
||||
if (*T1 != *T2) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return A.ReturnType != B.ReturnType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TypeIterator Type::begin() {
|
||||
return TypeIterator { this, getStartIndex() };
|
||||
}
|
||||
|
||||
TypeIterator Type::end() {
|
||||
return TypeIterator { this, getEndIndex() };
|
||||
}
|
||||
|
||||
TypeIndex Type::getStartIndex() {
|
||||
switch (Kind) {
|
||||
case TypeKind::Con:
|
||||
{
|
||||
auto Con = static_cast<TCon*>(this);
|
||||
if (Con->Args.empty()) {
|
||||
return TypeIndex(TypeIndexKind::End);
|
||||
}
|
||||
return TypeIndex::forConArg(0);
|
||||
}
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
if (Arrow->ParamTypes.empty()) {
|
||||
return TypeIndex::forArrowReturnType();
|
||||
}
|
||||
return TypeIndex::forArrowParamType(0);
|
||||
}
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
if (Tuple->ElementTypes.empty()) {
|
||||
return TypeIndex(TypeIndexKind::End);
|
||||
}
|
||||
return TypeIndex::forTupleElement(0);
|
||||
}
|
||||
default:
|
||||
return TypeIndex(TypeIndexKind::End);
|
||||
}
|
||||
}
|
||||
|
||||
TypeIndex Type::getEndIndex() {
|
||||
return TypeIndex(TypeIndexKind::End);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue