#include "llvm/Support/Casting.h"

#include "zen/config.hpp"
#include "zen/range.hpp"

#include "bolt/Type.hpp"

namespace bolt {

  namespace {

    /// Returns true when both type lists have the same length and are
    /// pairwise structurally equal according to `Type::operator==`.
    ///
    /// Works on const references so that comparing two nodes never copies
    /// their child lists.
    template<typename ListT>
    bool typeListsEqual(const ListT& Left, const ListT& Right) {
      if (Left.size() != Right.size()) {
        return false;
      }
      for (std::size_t I = 0; I < Left.size(); ++I) {
        if (!(*Left[I] == *Right[I])) {
          return false;
        }
      }
      return true;
    }

  }

  /// Orders signatures by typeclass `Id` first, then by the `Id` of the
  /// single parameter type.
  ///
  /// This is a strict weak ordering: when the `Id`s differ, they alone decide
  /// the result. (Previously the `Other.Id < Id` case fell through to the
  /// parameter comparison, so `A < B` and `B < A` could both hold.)
  /// Both signatures are asserted to have exactly one parameter once the
  /// `Id`s are equal.
  bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
    if (Id < Other.Id) {
      return true;
    }
    if (Other.Id < Id) {
      return false;
    }
    ZEN_ASSERT(Params.size() == 1);
    ZEN_ASSERT(Other.Params.size() == 1);
    return Params[0]->Id < Other.Params[0]->Id;
  }

  /// Two signatures are equal when their typeclass `Id`s match and their
  /// single parameters have the same `Id`.
  bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
    ZEN_ASSERT(Params.size() == 1);
    ZEN_ASSERT(Other.Params.size() == 1);
    return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id;
  }

  /// Indices are equal when their kinds match. For kinds that carry a
  /// position (`ConArg`, `ArrowParamType`, `TupleElement`), the position `I`
  /// must also match.
  bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
    if (Kind != Other.Kind) {
      return false;
    }
    switch (Kind) {
      case TypeIndexKind::ConArg:
      case TypeIndexKind::ArrowParamType:
      case TypeIndexKind::TupleElement:
        return I == Other.I;
      default:
        return true;
    }
  }

  /// Moves this index to the next child position of `Ty`.
  ///
  /// - `ArrowParamType` steps through the parameters, then moves on to
  ///   `ArrowReturnType`.
  /// - `ArrowReturnType` becomes `End`.
  /// - `ConArg` and `TupleElement` step through their elements, then become
  ///   `End`.
  /// - `End` stays `End`.
  ///
  /// `Ty` must be the type this index was created for; `llvm::cast` checks
  /// this.
  void TypeIndex::advance(const Type* Ty) {
    switch (Kind) {
      case TypeIndexKind::End:
        break;
      case TypeIndexKind::ArrowParamType:
      {
        auto Arrow = llvm::cast<TArrow>(Ty);
        if (I+1 < Arrow->ParamTypes.size()) {
          ++I;
        } else {
          Kind = TypeIndexKind::ArrowReturnType;
        }
        break;
      }
      case TypeIndexKind::ArrowReturnType:
        Kind = TypeIndexKind::End;
        break;
      case TypeIndexKind::ConArg:
      {
        auto Con = llvm::cast<TCon>(Ty);
        if (I+1 < Con->Args.size()) {
          ++I;
        } else {
          Kind = TypeIndexKind::End;
        }
        break;
      }
      case TypeIndexKind::TupleElement:
      {
        auto Tuple = llvm::cast<TTuple>(Ty);
        if (I+1 < Tuple->ElementTypes.size()) {
          ++I;
        } else {
          Kind = TypeIndexKind::End;
        }
        break;
      }
    }
  }

  /// Inserts every type variable occurring anywhere in this type into `TVs`.
  void Type::addTypeVars(TVSet& TVs) {
    switch (Kind) {
      case TypeKind::Var:
        TVs.emplace(static_cast<TVar*>(this));
        break;
      case TypeKind::Arrow:
      {
        auto Arrow = static_cast<TArrow*>(this);
        for (auto Ty: Arrow->ParamTypes) {
          Ty->addTypeVars(TVs);
        }
        Arrow->ReturnType->addTypeVars(TVs);
        break;
      }
      case TypeKind::Con:
      {
        auto Con = static_cast<TCon*>(this);
        for (auto Ty: Con->Args) {
          Ty->addTypeVars(TVs);
        }
        break;
      }
      case TypeKind::TupleIndex:
      {
        auto Index = static_cast<TTupleIndex*>(this);
        Index->Ty->addTypeVars(TVs);
        break;
      }
      case TypeKind::Tuple:
      {
        auto Tuple = static_cast<TTuple*>(this);
        for (auto Ty: Tuple->ElementTypes) {
          Ty->addTypeVars(TVs);
        }
        break;
      }
    }
  }

  /// Returns true if `TV` occurs anywhere within this type. Type variables
  /// are compared by `Id`, not by pointer identity.
  bool Type::hasTypeVar(const TVar* TV) {
    switch (Kind) {
      case TypeKind::Var:
        return static_cast<TVar*>(this)->Id == TV->Id;
      case TypeKind::Arrow:
      {
        auto Arrow = static_cast<TArrow*>(this);
        for (auto Ty: Arrow->ParamTypes) {
          if (Ty->hasTypeVar(TV)) {
            return true;
          }
        }
        return Arrow->ReturnType->hasTypeVar(TV);
      }
      case TypeKind::Con:
      {
        auto Con = static_cast<TCon*>(this);
        for (auto Ty: Con->Args) {
          if (Ty->hasTypeVar(TV)) {
            return true;
          }
        }
        return false;
      }
      case TypeKind::TupleIndex:
      {
        auto Index = static_cast<TTupleIndex*>(this);
        return Index->Ty->hasTypeVar(TV);
      }
      case TypeKind::Tuple:
      {
        auto Tuple = static_cast<TTuple*>(this);
        for (auto Ty: Tuple->ElementTypes) {
          if (Ty->hasTypeVar(TV)) {
            return true;
          }
        }
        return false;
      }
    }
    // Every TypeKind is handled above; reaching here means a corrupt Kind.
    ZEN_UNREACHABLE
  }

  /// Applies the substitution `Sub` to this type.
  ///
  /// - A type variable found in `Sub` is replaced by its mapping, and that
  ///   mapping is itself substituted again.
  /// - For composite types, new nodes are allocated with `new` only when
  ///   some child actually changed. Otherwise `this` is returned unchanged.
  ///
  /// NOTE(review): the Var case recurses without bound if `Sub` maps a
  /// variable to a type containing that same variable. Callers presumably
  /// run an occurs check first; confirm this.
  /// NOTE(review): ownership of the newly allocated nodes is not managed
  /// here; confirm who frees them.
  Type* Type::substitute(const TVSub& Sub) {
    switch (Kind) {
      case TypeKind::Var:
      {
        auto TV = static_cast<TVar*>(this);
        auto Match = Sub.find(TV);
        return Match != Sub.end() ? Match->second->substitute(Sub) : this;
      }
      case TypeKind::Arrow:
      {
        auto Arrow = static_cast<TArrow*>(this);
        bool Changed = false;
        std::vector<Type*> NewParamTypes;
        NewParamTypes.reserve(Arrow->ParamTypes.size());
        for (auto Ty: Arrow->ParamTypes) {
          auto NewParamType = Ty->substitute(Sub);
          if (NewParamType != Ty) {
            Changed = true;
          }
          NewParamTypes.push_back(NewParamType);
        }
        auto NewRetTy = Arrow->ReturnType->substitute(Sub);
        if (NewRetTy != Arrow->ReturnType) {
          Changed = true;
        }
        return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
      }
      case TypeKind::Con:
      {
        auto Con = static_cast<TCon*>(this);
        bool Changed = false;
        std::vector<Type*> NewArgs;
        NewArgs.reserve(Con->Args.size());
        for (auto Arg: Con->Args) {
          auto NewArg = Arg->substitute(Sub);
          if (NewArg != Arg) {
            Changed = true;
          }
          NewArgs.push_back(NewArg);
        }
        return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
      }
      case TypeKind::TupleIndex:
      {
        auto Index = static_cast<TTupleIndex*>(this);
        auto NewTy = Index->Ty->substitute(Sub);
        return NewTy != Index->Ty ? new TTupleIndex(NewTy, Index->I) : Index;
      }
      case TypeKind::Tuple:
      {
        auto Tuple = static_cast<TTuple*>(this);
        bool Changed = false;
        std::vector<Type*> NewElementTypes;
        NewElementTypes.reserve(Tuple->ElementTypes.size());
        for (auto Ty: Tuple->ElementTypes) {
          auto NewElementType = Ty->substitute(Sub);
          if (NewElementType != Ty) {
            Changed = true;
          }
          NewElementTypes.push_back(NewElementType);
        }
        return Changed ? new TTuple(NewElementTypes) : this;
      }
    }
    // Every TypeKind is handled above; reaching here means a corrupt Kind.
    ZEN_UNREACHABLE
  }

  /// Returns the child type that `Index` points at.
  ///
  /// `Index` must have been produced for this type; `llvm::cast` checks the
  /// kind. Resolving an `End` index is a programming error.
  Type* Type::resolve(const TypeIndex& Index) const noexcept {
    switch (Index.Kind) {
      case TypeIndexKind::ConArg:
        return llvm::cast<TCon>(this)->Args[Index.I];
      case TypeIndexKind::TupleElement:
        return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
      case TypeIndexKind::ArrowParamType:
        return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
      case TypeIndexKind::ArrowReturnType:
        return llvm::cast<TArrow>(this)->ReturnType;
      case TypeIndexKind::End:
        ZEN_UNREACHABLE
    }
    ZEN_UNREACHABLE
  }

  /// Structural equality.
  ///
  /// - Type variables compare by `Id`.
  /// - Constructors compare by `Id` and, element-wise, by their arguments.
  /// - Tuples and arrows compare their children element-wise.
  /// - Tuple-index types compare their index and inner type.
  bool Type::operator==(const Type& Other) const noexcept {
    switch (Kind) {
      case TypeKind::Var:
        if (Other.Kind != TypeKind::Var) {
          return false;
        }
        return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
      case TypeKind::Tuple:
      {
        if (Other.Kind != TypeKind::Tuple) {
          return false;
        }
        // Bind by reference: plain `auto` would copy the whole node.
        const auto& A = static_cast<const TTuple&>(*this);
        const auto& B = static_cast<const TTuple&>(Other);
        return typeListsEqual(A.ElementTypes, B.ElementTypes);
      }
      case TypeKind::TupleIndex:
      {
        if (Other.Kind != TypeKind::TupleIndex) {
          return false;
        }
        const auto& A = static_cast<const TTupleIndex&>(*this);
        const auto& B = static_cast<const TTupleIndex&>(Other);
        return A.I == B.I && *A.Ty == *B.Ty;
      }
      case TypeKind::Con:
      {
        if (Other.Kind != TypeKind::Con) {
          return false;
        }
        const auto& A = static_cast<const TCon&>(*this);
        const auto& B = static_cast<const TCon&>(Other);
        if (A.Id != B.Id) {
          return false;
        }
        return typeListsEqual(A.Args, B.Args);
      }
      case TypeKind::Arrow:
      {
        // FIXME Do we really need to 'curry' this type?
        if (Other.Kind != TypeKind::Arrow) {
          return false;
        }
        const auto& A = static_cast<const TArrow&>(*this);
        const auto& B = static_cast<const TArrow&>(Other);
        if (!typeListsEqual(A.ParamTypes, B.ParamTypes)) {
          return false;
        }
        // Compare the return types structurally. (Previously this was
        // `A.ReturnType != B.ReturnType`, an inverted pointer comparison.)
        return *A.ReturnType == *B.ReturnType;
      }
    }
    // Every TypeKind is handled above; reaching here means a corrupt Kind.
    ZEN_UNREACHABLE
  }

  /// Iterator positioned at the first child of this type, or at the end for
  /// childless types.
  TypeIterator Type::begin() {
    return TypeIterator { this, getStartIndex() };
  }

  /// Past-the-end iterator over the children of this type.
  TypeIterator Type::end() {
    return TypeIterator { this, getEndIndex() };
  }

  /// Returns the index of the first child of this type.
  ///
  /// - Constructors and tuples with no elements start at `End`.
  /// - An arrow with no parameters starts at its return type.
  /// - All other kinds (variables, tuple-index types) start at `End`.
  ///
  /// NOTE(review): the inner `Ty` of a TTupleIndex is therefore never
  /// visited by iteration; confirm this is intended.
  TypeIndex Type::getStartIndex() {
    switch (Kind) {
      case TypeKind::Con:
      {
        auto Con = static_cast<TCon*>(this);
        if (Con->Args.empty()) {
          return TypeIndex(TypeIndexKind::End);
        }
        return TypeIndex::forConArg(0);
      }
      case TypeKind::Arrow:
      {
        auto Arrow = static_cast<TArrow*>(this);
        if (Arrow->ParamTypes.empty()) {
          return TypeIndex::forArrowReturnType();
        }
        return TypeIndex::forArrowParamType(0);
      }
      case TypeKind::Tuple:
      {
        auto Tuple = static_cast<TTuple*>(this);
        if (Tuple->ElementTypes.empty()) {
          return TypeIndex(TypeIndexKind::End);
        }
        return TypeIndex::forTupleElement(0);
      }
      default:
        return TypeIndex(TypeIndexKind::End);
    }
  }

  /// The sentinel index that marks the end of iteration for every type.
  TypeIndex Type::getEndIndex() {
    return TypeIndex(TypeIndexKind::End);
  }

}