Simplify TArrow type as described in issue #42
This commit is contained in:
parent
eef23feb1c
commit
4294063921
4 changed files with 68 additions and 114 deletions
|
@ -81,8 +81,8 @@ namespace bolt {
|
|||
return { TypeIndexKind::FieldRestType };
|
||||
}
|
||||
|
||||
static TypeIndex forArrowParamType(std::size_t I) {
|
||||
return { TypeIndexKind::ArrowParamType, I };
|
||||
static TypeIndex forArrowParamType() {
|
||||
return { TypeIndexKind::ArrowParamType };
|
||||
}
|
||||
|
||||
static TypeIndex forArrowReturnType() {
|
||||
|
@ -303,16 +303,25 @@ namespace bolt {
|
|||
class TArrow : public Type {
|
||||
public:
|
||||
|
||||
std::vector<Type*> ParamTypes;
|
||||
Type* ParamType;
|
||||
Type* ReturnType;
|
||||
|
||||
inline TArrow(
|
||||
std::vector<Type*> ParamTypes,
|
||||
Type* ParamType,
|
||||
Type* ReturnType
|
||||
): Type(TypeKind::Arrow),
|
||||
ParamTypes(ParamTypes),
|
||||
ParamType(ParamType),
|
||||
ReturnType(ReturnType) {}
|
||||
|
||||
static Type* build(std::vector<Type*> ParamTypes, Type* ReturnType) {
|
||||
Type* Curr = ReturnType;
|
||||
for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) {
|
||||
Curr = new TArrow(*Iter, Curr);
|
||||
}
|
||||
return Curr;
|
||||
}
|
||||
|
||||
|
||||
static bool classof(const Type* Ty) {
|
||||
return Ty->getKind() == TypeKind::Arrow;
|
||||
}
|
||||
|
@ -474,9 +483,7 @@ namespace bolt {
|
|||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<C<TArrow>*>(Ty);
|
||||
for (auto I = 0; I < Arrow->ParamTypes.size(); ++I) {
|
||||
visit(Arrow->ParamTypes[I]);
|
||||
}
|
||||
visit(Arrow->ParamType);
|
||||
visit(Arrow->ReturnType);
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -313,7 +313,7 @@ namespace bolt {
|
|||
for (auto Element: TupleMember->Elements) {
|
||||
ParamTypes.push_back(inferTypeExpression(Element));
|
||||
}
|
||||
Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
|
||||
Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, TArrow::build(ParamTypes, RetTy)));
|
||||
break;
|
||||
}
|
||||
case NodeKind::RecordVariantDeclarationMember:
|
||||
|
@ -358,7 +358,7 @@ namespace bolt {
|
|||
for (auto TV: Vars) {
|
||||
RetTy = new TApp(RetTy, TV);
|
||||
}
|
||||
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
|
||||
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(FieldsTy, RetTy)));
|
||||
popContext();
|
||||
|
||||
break;
|
||||
|
@ -570,7 +570,7 @@ namespace bolt {
|
|||
RetType = createTypeVar();
|
||||
}
|
||||
|
||||
makeEqual(Decl->Ty, new TArrow(ParamTypes, RetType), Decl);
|
||||
makeEqual(Decl->Ty, TArrow::build(ParamTypes, RetType), Decl);
|
||||
|
||||
}
|
||||
|
||||
|
@ -849,7 +849,7 @@ namespace bolt {
|
|||
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
|
||||
}
|
||||
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
|
||||
auto Ty = new TArrow(ParamTypes, ReturnType);
|
||||
auto Ty = TArrow::build(ParamTypes, ReturnType);
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
}
|
||||
|
@ -910,7 +910,7 @@ namespace bolt {
|
|||
setContext(OldCtx);
|
||||
}
|
||||
if (!Match->Value) {
|
||||
Ty = new TArrow({ ValTy }, Ty);
|
||||
Ty = new TArrow(ValTy, Ty);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -962,7 +962,7 @@ namespace bolt {
|
|||
for (auto Arg: Call->Args) {
|
||||
ArgTypes.push_back(inferExpression(Arg));
|
||||
}
|
||||
makeEqual(OpTy, new TArrow(ArgTypes, Ty), X);
|
||||
makeEqual(OpTy, TArrow::build(ArgTypes, Ty), X);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -979,7 +979,7 @@ namespace bolt {
|
|||
std::vector<Type*> ArgTys;
|
||||
ArgTys.push_back(inferExpression(Infix->LHS));
|
||||
ArgTys.push_back(inferExpression(Infix->RHS));
|
||||
makeEqual(new TArrow(ArgTys, Ty), OpTy, X);
|
||||
makeEqual(TArrow::build(ArgTys, Ty), OpTy, X);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -1066,7 +1066,7 @@ namespace bolt {
|
|||
}
|
||||
auto Ty = instantiate(Scm, P);
|
||||
auto RetTy = createTypeVar();
|
||||
makeEqual(Ty, new TArrow(ParamTypes, RetTy), P);
|
||||
makeEqual(Ty, TArrow::build(ParamTypes, RetTy), P);
|
||||
return RetTy;
|
||||
}
|
||||
|
||||
|
@ -1181,11 +1181,11 @@ namespace bolt {
|
|||
addBinding("True", new Forall(BoolType));
|
||||
addBinding("False", new Forall(BoolType));
|
||||
auto A = createTypeVar();
|
||||
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, new TArrow({ A, A }, BoolType)));
|
||||
addBinding("+", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("-", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("*", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("/", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, TArrow::build({ A, A }, BoolType)));
|
||||
addBinding("+", new Forall(TArrow::build({ IntType, IntType }, IntType)));
|
||||
addBinding("-", new Forall(TArrow::build({ IntType, IntType }, IntType)));
|
||||
addBinding("*", new Forall(TArrow::build({ IntType, IntType }, IntType)));
|
||||
addBinding("/", new Forall(TArrow::build({ IntType, IntType }, IntType)));
|
||||
populate(SF);
|
||||
forwardDeclare(SF);
|
||||
auto SCCs = RefGraph.strongconnect();
|
||||
|
@ -1593,62 +1593,47 @@ namespace bolt {
|
|||
}
|
||||
|
||||
if (llvm::isa<TArrow>(A) && llvm::isa<TArrow>(B)) {
|
||||
auto C1 = ArrowCursor(static_cast<TArrow*>(A), DidSwap ? RightPath : LeftPath);
|
||||
auto C2 = ArrowCursor(static_cast<TArrow*>(B), DidSwap ? LeftPath : RightPath);
|
||||
auto Arrow1 = static_cast<TArrow*>(A);
|
||||
auto Arrow2 = static_cast<TArrow*>(B);
|
||||
bool Success = true;
|
||||
for (;;) {
|
||||
auto T1 = C1.next();
|
||||
auto T2 = C2.next();
|
||||
if (T1 == nullptr && T2 == nullptr) {
|
||||
break;
|
||||
}
|
||||
if (T1 == nullptr || T2 == nullptr) {
|
||||
unifyError();
|
||||
Success = false;
|
||||
break;
|
||||
}
|
||||
if (!unify(T1, T2, DidSwap)) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.push_back(TypeIndex::forArrowParamType());
|
||||
RightPath.push_back(TypeIndex::forArrowParamType());
|
||||
if (!unify(Arrow1->ParamType, Arrow2->ParamType, DidSwap)) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.pop_back();
|
||||
RightPath.pop_back();
|
||||
LeftPath.push_back(TypeIndex::forArrowReturnType());
|
||||
RightPath.push_back(TypeIndex::forArrowReturnType());
|
||||
if (!unify(Arrow1->ReturnType, Arrow2->ReturnType, DidSwap)) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.pop_back();
|
||||
RightPath.pop_back();
|
||||
return Success;
|
||||
/* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */
|
||||
/* return false; */
|
||||
/* } */
|
||||
/* auto Count = Arr1->ParamTypes.size(); */
|
||||
/* for (std::size_t I = 0; I < Count; I++) { */
|
||||
/* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */
|
||||
/* return false; */
|
||||
/* } */
|
||||
/* } */
|
||||
/* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */
|
||||
}
|
||||
|
||||
if (llvm::isa<TApp>(A) && llvm::isa<TApp>(B)) {
|
||||
auto App1 = static_cast<TApp*>(A);
|
||||
auto App2 = static_cast<TApp*>(B);
|
||||
bool Success = true;
|
||||
LeftPath.push_back(TypeIndex::forAppOpType());
|
||||
RightPath.push_back(TypeIndex::forAppOpType());
|
||||
if (!unify(App1->Op, App2->Op, DidSwap)) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.pop_back();
|
||||
RightPath.pop_back();
|
||||
LeftPath.push_back(TypeIndex::forAppArgType());
|
||||
RightPath.push_back(TypeIndex::forAppArgType());
|
||||
if (!unify(App1->Arg, App2->Arg, DidSwap)) {
|
||||
Success = false;
|
||||
}
|
||||
LeftPath.pop_back();
|
||||
RightPath.pop_back();
|
||||
return Success;
|
||||
}
|
||||
|
||||
if (llvm::isa<TArrow>(B)) {
|
||||
swap();
|
||||
}
|
||||
|
||||
if (llvm::isa<TArrow>(A)) {
|
||||
auto Arr = static_cast<TArrow*>(A);
|
||||
if (Arr->ParamTypes.empty()) {
|
||||
auto Success = unify(Arr->ReturnType, B, DidSwap);
|
||||
return Success;
|
||||
}
|
||||
}
|
||||
|
||||
if (llvm::isa<TTuple>(A) && llvm::isa<TTuple>(B)) {
|
||||
auto Tuple1 = static_cast<TTuple*>(A);
|
||||
auto Tuple2 = static_cast<TTuple*>(B);
|
||||
|
|
|
@ -165,14 +165,7 @@ namespace bolt {
|
|||
{
|
||||
auto Y = static_cast<const TArrow*>(Ty);
|
||||
std::ostringstream Out;
|
||||
Out << "(";
|
||||
bool First = true;
|
||||
for (auto PT: Y->ParamTypes) {
|
||||
if (First) First = false;
|
||||
else Out << ", ";
|
||||
Out << describe(PT);
|
||||
}
|
||||
Out << ") -> " << describe(Y->ReturnType);
|
||||
Out << describe(Y->ParamType) << " -> " << describe(Y->ReturnType);
|
||||
return Out.str();
|
||||
}
|
||||
case TypeKind::Con:
|
||||
|
@ -558,17 +551,10 @@ namespace bolt {
|
|||
}
|
||||
|
||||
void visitArrowType(const TArrow* Ty) override {
|
||||
W.write("(");
|
||||
bool First = true;
|
||||
std::size_t I = 0;
|
||||
for (auto PT: Ty->ParamTypes) {
|
||||
if (First) First = false;
|
||||
else W.write(", ");
|
||||
Path.push_back(TypeIndex::forArrowParamType(I++));
|
||||
visit(PT);
|
||||
Path.pop_back();
|
||||
}
|
||||
W.write(") -> ");
|
||||
Path.push_back(TypeIndex::forArrowParamType());
|
||||
visit(Ty->ParamType);
|
||||
Path.pop_back();
|
||||
W.write(" -> ");
|
||||
Path.push_back(TypeIndex::forArrowReturnType());
|
||||
visit(Ty->ReturnType);
|
||||
Path.pop_back();
|
||||
|
|
48
src/Types.cc
48
src/Types.cc
|
@ -44,15 +44,11 @@ namespace bolt {
|
|||
Kind = TypeIndexKind::AppArgType;
|
||||
break;
|
||||
case TypeIndexKind::ArrowParamType:
|
||||
{
|
||||
auto Arrow = llvm::cast<TArrow>(Ty);
|
||||
if (I+1 < Arrow->ParamTypes.size()) {
|
||||
++I;
|
||||
} else {
|
||||
Kind = TypeIndexKind::ArrowReturnType;
|
||||
}
|
||||
Kind = TypeIndexKind::ArrowReturnType;
|
||||
break;
|
||||
case TypeIndexKind::ArrowReturnType:
|
||||
Kind = TypeIndexKind::End;
|
||||
break;
|
||||
}
|
||||
case TypeIndexKind::FieldType:
|
||||
Kind = TypeIndexKind::FieldRestType;
|
||||
break;
|
||||
|
@ -60,9 +56,6 @@ namespace bolt {
|
|||
case TypeIndexKind::TupleIndexType:
|
||||
case TypeIndexKind::PresentType:
|
||||
case TypeIndexKind::AppArgType:
|
||||
case TypeIndexKind::ArrowReturnType:
|
||||
Kind = TypeIndexKind::End;
|
||||
break;
|
||||
case TypeIndexKind::TupleElement:
|
||||
{
|
||||
auto Tuple = llvm::cast<TTuple>(Ty);
|
||||
|
@ -88,19 +81,15 @@ namespace bolt {
|
|||
{
|
||||
auto Arrow = static_cast<TArrow*>(Ty2);
|
||||
bool Changed = false;
|
||||
std::vector<Type*> NewParamTypes;
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
auto NewParamType = Ty->rewrite(Fn);
|
||||
if (NewParamType != Ty) {
|
||||
Changed = true;
|
||||
}
|
||||
NewParamTypes.push_back(NewParamType);
|
||||
Type* NewParamType = Arrow->ParamType->rewrite(Fn);
|
||||
if (NewParamType != Arrow->ParamType) {
|
||||
Changed = true;
|
||||
}
|
||||
auto NewRetTy = Arrow->ReturnType->rewrite(Fn);
|
||||
if (NewRetTy != Arrow->ReturnType) {
|
||||
Changed = true;
|
||||
}
|
||||
return Changed ? new TArrow(NewParamTypes, NewRetTy) : Ty2;
|
||||
return Changed ? new TArrow(NewParamType, NewRetTy) : Ty2;
|
||||
}
|
||||
case TypeKind::Con:
|
||||
return Ty2;
|
||||
|
@ -173,9 +162,7 @@ namespace bolt {
|
|||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
Ty->addTypeVars(TVs);
|
||||
}
|
||||
Arrow->ParamType->addTypeVars(TVs);
|
||||
Arrow->ReturnType->addTypeVars(TVs);
|
||||
break;
|
||||
}
|
||||
|
@ -229,12 +216,7 @@ namespace bolt {
|
|||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
for (auto Ty: Arrow->ParamTypes) {
|
||||
if (Ty->hasTypeVar(TV)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return Arrow->ReturnType->hasTypeVar(TV);
|
||||
return Arrow->ParamType->hasTypeVar(TV) || Arrow->ReturnType->hasTypeVar(TV);
|
||||
}
|
||||
case TypeKind::Con:
|
||||
return false;
|
||||
|
@ -308,7 +290,7 @@ namespace bolt {
|
|||
case TypeIndexKind::TupleElement:
|
||||
return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
|
||||
case TypeIndexKind::ArrowParamType:
|
||||
return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
|
||||
return llvm::cast<TArrow>(this)->ParamType;
|
||||
case TypeIndexKind::ArrowReturnType:
|
||||
return llvm::cast<TArrow>(this)->ReturnType;
|
||||
case TypeIndexKind::FieldType:
|
||||
|
@ -445,13 +427,7 @@ namespace bolt {
|
|||
TypeIndex Type::getStartIndex() {
|
||||
switch (Kind) {
|
||||
case TypeKind::Arrow:
|
||||
{
|
||||
auto Arrow = static_cast<TArrow*>(this);
|
||||
if (Arrow->ParamTypes.empty()) {
|
||||
return TypeIndex::forArrowReturnType();
|
||||
}
|
||||
return TypeIndex::forArrowParamType(0);
|
||||
}
|
||||
return TypeIndex::forArrowParamType();
|
||||
case TypeKind::Tuple:
|
||||
{
|
||||
auto Tuple = static_cast<TTuple*>(this);
|
||||
|
|
Loading…
Reference in a new issue