diff --git a/include/bolt/Common.hpp b/include/bolt/Common.hpp index 158635610..1c360c532 100644 --- a/include/bolt/Common.hpp +++ b/include/bolt/Common.hpp @@ -32,4 +32,21 @@ namespace bolt { }; + template + D* cast(B* base) { + ZEN_ASSERT(D::classof(base)); + return static_cast(base); + } + + template + const D* cast(const B* base) { + ZEN_ASSERT(D::classof(base)); + return static_cast(base); + } + + template + bool isa(const T* value) { + return D::classof(value); + } + } diff --git a/src/Checker.cc b/src/Checker.cc index 3f143a81b..61a2f196d 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -4,8 +4,6 @@ #include #include -#include "llvm/Support/Casting.h" - #include "bolt/Type.hpp" #include "zen/config.hpp" #include "zen/range.hpp" @@ -484,8 +482,8 @@ namespace bolt { if (Let->isInstance()) { auto Instance = static_cast(Let->Parent); - auto Class = llvm::cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); - auto SigLet = llvm::cast(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); + auto Class = cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); + auto SigLet = cast(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); auto Params = addClassVars(Class, false); @@ -498,7 +496,7 @@ namespace bolt { // TVSub Sub; // for (auto TE: Class->TypeVars) { // auto TV = createTypeVar(); - // Sub.emplace(llvm::cast(TE->getType()), TV); + // Sub.emplace(cast(TE->getType()), TV); // Params.push_back(TV); // } @@ -1283,9 +1281,9 @@ namespace bolt { } bool assignableTo(Type* A, Type* B) { - if (llvm::isa(A) && llvm::isa(B)) { - auto Con1 = llvm::cast(A); - auto Con2 = llvm::cast(B); + if (isa(A) && isa(B)) { + auto Con1 = cast(A); + auto Con2 = cast(B); if (Con1->Id != Con2-> Id) { return false; } @@ -1325,7 +1323,7 @@ namespace bolt { continue; } Ty = Arrow->resolve(Index); - if (llvm::isa(Ty)) { + if (isa(Ty)) { auto NewIndex = Arrow->getStartIndex(); Stack.push({ static_cast(Ty), true }); Path.push_back(NewIndex); @@ -1412,8 +1410,8 @@ namespace bolt { } void propagateClasses(std::unordered_set& Classes, Type* Ty) { - if (llvm::isa(Ty)) { - auto TV = llvm::cast(Ty); + if (isa(Ty)) { + auto TV = cast(Ty); for (auto Class: Classes) { TV->Contexts.emplace(Class); } @@ -1425,7 +1423,7 @@ namespace bolt { } } } - } else if (llvm::isa(Ty) || llvm::isa(Ty)) { + } else if (isa(Ty) || isa(Ty)) { auto Sig = getTypeSig(Ty); for (auto Class: Classes) { propagateClassTycon(Class, Sig); @@ -1482,14 +1480,14 @@ namespace bolt { }; bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { return true; } - if (llvm::isa(B)) { + if (isa(B)) { std::swap(A, B); DidSwap = !DidSwap; } - if (llvm::isa(A)) { + if (isa(A)) { auto Present = static_cast(B); C.DE.add(CurrentFieldName, C.simplifyType(getLeft()), LeftPath, getSource()); return false; @@ -1551,7 +1549,7 @@ namespace bolt { DidSwap = !DidSwap; }; - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto Var1 = static_cast(A); auto Var2 = static_cast(B); if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { @@ -1578,11 +1576,11 @@ namespace bolt { return true; } - if (llvm::isa(B)) { + if (isa(B)) { swap(); } - if (llvm::isa(A)) { + if (isa(A)) { auto TV = static_cast(A); @@ -1607,7 +1605,7 @@ namespace bolt { return true; } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto Arrow1 = static_cast(A); auto Arrow2 = static_cast(B); bool Success = true; @@ -1628,7 +1626,7 @@ namespace bolt { return Success; } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto App1 = static_cast(A); auto App2 = static_cast(B); bool Success = true; @@ -1649,7 +1647,7 @@ namespace bolt { return Success; } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto Tuple1 = static_cast(A); auto Tuple2 = static_cast(B); if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { @@ -1670,20 +1668,20 @@ namespace bolt { return Success; } - if (llvm::isa(A) || llvm::isa(B)) { + if (isa(A) || isa(B)) { // Type(s) could not be simplified at the beginning of this function, // so we have to re-visit the constraint when there is more information. C.Queue.push_back(Constraint); return true; } - // if (llvm::isa(A) && llvm::isa(B)) { + // if (isa(A) && isa(B)) { // auto Index1 = static_cast(A); // auto Index2 = static_cast(B); // return unify(Index1->Ty, Index2->Ty, Source); // } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto Con1 = static_cast(A); auto Con2 = static_cast(B); if (Con1->Id != Con2->Id) { @@ -1693,11 +1691,11 @@ namespace bolt { return true; } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { return true; } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto Field1 = static_cast(A); auto Field2 = static_cast(B); bool Success = true; @@ -1733,11 +1731,11 @@ namespace bolt { return Success; } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { swap(); } - if (llvm::isa(A) && llvm::isa(B)) { + if (isa(A) && isa(B)) { auto Field = static_cast(A); bool Success = true; pushLeft(TypeIndex::forFieldType()); diff --git a/src/Parser.cc b/src/Parser.cc index 98e6279c2..776daf020 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -1,11 +1,9 @@ // TODO check for memory leaks everywhere a nullptr is returned -#include #include -#include "llvm/Support/Casting.h" - +#include "bolt/Common.hpp" #include "bolt/CST.hpp" #include "bolt/Scanner.hpp" #include "bolt/Parser.hpp" @@ -268,7 +266,7 @@ finish: TypeExpression* Parser::parseQualifiedTypeExpression() { bool HasConstraints = false; auto T0 = Tokens.peek(); - if (llvm::isa(T0)) { + if (isa(T0)) { std::size_t I = 1; for (;;) { auto T0 = Tokens.peek(I++); @@ -363,7 +361,7 @@ after_constraints: RParen* RParen; for (;;) { auto T1 = Tokens.peek(); - if (llvm::isa(T1)) { + if (isa(T1)) { Tokens.get(); RParen = static_cast(T1); break; @@ -499,7 +497,7 @@ after_tuple_element: auto T1 = Tokens.peek(); Expression* Value; BlockStart* BlockStart; - if (llvm::isa(T1)) { + if (isa(T1)) { Value = nullptr; BlockStart = static_cast(T1); Tokens.get(); @@ -519,7 +517,7 @@ after_tuple_element: std::vector Cases; for (;;) { auto T2 = Tokens.peek(); - if (llvm::isa(T2)) { + if (isa(T2)) { Tokens.get()->unref(); break; } @@ -627,7 +625,7 @@ after_tuple_element: for (;;) { auto T1 = Tokens.peek(0); auto T2 = Tokens.peek(1); - if (!llvm::isa(T1) || !llvm::isa(T2)) { + if (!isa(T1) || !isa(T2)) { break; } Tokens.get(); @@ -635,7 +633,7 @@ after_tuple_element: ModulePath.push_back(std::make_tuple(static_cast(T1), static_cast(T2))); } auto T3 = Tokens.get(); - if (!llvm::isa(T3)) { + if (!isa(T3)) { for (auto [Name, Dot]: ModulePath) { Name->unref(); Dot->unref(); @@ -652,7 +650,7 @@ after_tuple_element: auto LParen = static_cast(T0); RParen* RParen; auto T1 = Tokens.peek(); - if (llvm::isa(T1)) { + if (isa(T1)) { Tokens.get(); RParen = static_cast(T1); goto after_tuple_elements; @@ -731,7 +729,7 @@ after_tuple_elements: for (;;) { auto T1 = Tokens.peek(0); auto T2 = Tokens.peek(1); - if (!llvm::isa(T1)) { + if (!isa(T1)) { break; } switch (T2->getKind()) { diff --git a/src/Scanner.cc b/src/Scanner.cc index 6808035dc..cc3989a6a 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -3,8 +3,7 @@ #include "zen/config.hpp" -#include "llvm/Support/Casting.h" - +#include "bolt/Common.hpp" #include "bolt/Text.hpp" #include "bolt/Integer.hpp" #include "bolt/CST.hpp" @@ -475,7 +474,7 @@ after_string_contents: Locations.pop(); return new LineFoldEnd(T0->getStartLoc()); } - if (llvm::isa(T0)) { + if (isa(T0)) { auto T1 = Tokens.peek(1); if (T1->getStartLine() > T0->getEndLine()) { Tokens.get(); diff --git a/src/Types.cc b/src/Types.cc index a7d2ac52b..feffac871 100644 --- a/src/Types.cc +++ b/src/Types.cc @@ -1,9 +1,8 @@ -#include "llvm/Support/Casting.h" - #include "zen/config.hpp" #include "zen/range.hpp" +#include "bolt/Common.hpp" #include "bolt/Type.hpp" namespace bolt { @@ -58,7 +57,7 @@ namespace bolt { case TypeIndexKind::AppArgType: case TypeIndexKind::TupleElement: { - auto Tuple = llvm::cast(Ty); + auto Tuple = cast(Ty); if (I+1 < Tuple->ElementTypes.size()) { ++I; } else { @@ -271,7 +270,7 @@ namespace bolt { Type* Type::substitute(const TVSub &Sub) { return rewrite([&](auto Ty) { - if (llvm::isa(Ty)) { + if (isa(Ty)) { auto TV = static_cast(Ty); auto Match = Sub.find(TV); return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; @@ -283,23 +282,23 @@ namespace bolt { Type* Type::resolve(const TypeIndex& Index) const noexcept { switch (Index.Kind) { case TypeIndexKind::PresentType: - return llvm::cast(this)->Ty; + return cast(this)->Ty; case TypeIndexKind::AppOpType: - return llvm::cast(this)->Op; + return cast(this)->Op; case TypeIndexKind::AppArgType: - return llvm::cast(this)->Arg; + return cast(this)->Arg; case TypeIndexKind::TupleIndexType: - return llvm::cast(this)->Ty; + return cast(this)->Ty; case TypeIndexKind::TupleElement: - return llvm::cast(this)->ElementTypes[Index.I]; + return cast(this)->ElementTypes[Index.I]; case TypeIndexKind::ArrowParamType: - return llvm::cast(this)->ParamType; + return cast(this)->ParamType; case TypeIndexKind::ArrowReturnType: - return llvm::cast(this)->ReturnType; + return cast(this)->ReturnType; case TypeIndexKind::FieldType: - return llvm::cast(this)->Ty; + return cast(this)->Ty; case TypeIndexKind::FieldRestType: - return llvm::cast(this)->RestTy; + return cast(this)->RestTy; case TypeIndexKind::End: ZEN_UNREACHABLE }