Simplify TArrow type as described in issue #42

This commit is contained in:
Sam Vervaeck 2023-06-03 11:45:14 +02:00
parent eef23feb1c
commit 4294063921
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
4 changed files with 68 additions and 114 deletions

View file

@ -81,8 +81,8 @@ namespace bolt {
return { TypeIndexKind::FieldRestType };
}
static TypeIndex forArrowParamType(std::size_t I) {
return { TypeIndexKind::ArrowParamType, I };
static TypeIndex forArrowParamType() {
return { TypeIndexKind::ArrowParamType };
}
static TypeIndex forArrowReturnType() {
@ -303,16 +303,25 @@ namespace bolt {
class TArrow : public Type {
public:
std::vector<Type*> ParamTypes;
Type* ParamType;
Type* ReturnType;
inline TArrow(
std::vector<Type*> ParamTypes,
Type* ParamType,
Type* ReturnType
): Type(TypeKind::Arrow),
ParamTypes(ParamTypes),
ParamType(ParamType),
ReturnType(ReturnType) {}
static Type* build(std::vector<Type*> ParamTypes, Type* ReturnType) {
Type* Curr = ReturnType;
for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) {
Curr = new TArrow(*Iter, Curr);
}
return Curr;
}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Arrow;
}
@ -474,9 +483,7 @@ namespace bolt {
case TypeKind::Arrow:
{
auto Arrow = static_cast<C<TArrow>*>(Ty);
for (auto I = 0; I < Arrow->ParamTypes.size(); ++I) {
visit(Arrow->ParamTypes[I]);
}
visit(Arrow->ParamType);
visit(Arrow->ReturnType);
break;
}

View file

@ -313,7 +313,7 @@ namespace bolt {
for (auto Element: TupleMember->Elements) {
ParamTypes.push_back(inferTypeExpression(Element));
}
Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
Decl->Ctx->Parent->Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, TArrow::build(ParamTypes, RetTy)));
break;
}
case NodeKind::RecordVariantDeclarationMember:
@ -358,7 +358,7 @@ namespace bolt {
for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV);
}
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow(FieldsTy, RetTy)));
popContext();
break;
@ -570,7 +570,7 @@ namespace bolt {
RetType = createTypeVar();
}
makeEqual(Decl->Ty, new TArrow(ParamTypes, RetType), Decl);
makeEqual(Decl->Ty, TArrow::build(ParamTypes, RetType), Decl);
}
@ -849,7 +849,7 @@ namespace bolt {
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
}
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
auto Ty = new TArrow(ParamTypes, ReturnType);
auto Ty = TArrow::build(ParamTypes, ReturnType);
N->setType(Ty);
return Ty;
}
@ -910,7 +910,7 @@ namespace bolt {
setContext(OldCtx);
}
if (!Match->Value) {
Ty = new TArrow({ ValTy }, Ty);
Ty = new TArrow(ValTy, Ty);
}
break;
}
@ -962,7 +962,7 @@ namespace bolt {
for (auto Arg: Call->Args) {
ArgTypes.push_back(inferExpression(Arg));
}
makeEqual(OpTy, new TArrow(ArgTypes, Ty), X);
makeEqual(OpTy, TArrow::build(ArgTypes, Ty), X);
break;
}
@ -979,7 +979,7 @@ namespace bolt {
std::vector<Type*> ArgTys;
ArgTys.push_back(inferExpression(Infix->LHS));
ArgTys.push_back(inferExpression(Infix->RHS));
makeEqual(new TArrow(ArgTys, Ty), OpTy, X);
makeEqual(TArrow::build(ArgTys, Ty), OpTy, X);
break;
}
@ -1066,7 +1066,7 @@ namespace bolt {
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
makeEqual(Ty, new TArrow(ParamTypes, RetTy), P);
makeEqual(Ty, TArrow::build(ParamTypes, RetTy), P);
return RetTy;
}
@ -1181,11 +1181,11 @@ namespace bolt {
addBinding("True", new Forall(BoolType));
addBinding("False", new Forall(BoolType));
auto A = createTypeVar();
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, new TArrow({ A, A }, BoolType)));
addBinding("+", new Forall(new TArrow({ IntType, IntType }, IntType)));
addBinding("-", new Forall(new TArrow({ IntType, IntType }, IntType)));
addBinding("*", new Forall(new TArrow({ IntType, IntType }, IntType)));
addBinding("/", new Forall(new TArrow({ IntType, IntType }, IntType)));
addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, TArrow::build({ A, A }, BoolType)));
addBinding("+", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("-", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("*", new Forall(TArrow::build({ IntType, IntType }, IntType)));
addBinding("/", new Forall(TArrow::build({ IntType, IntType }, IntType)));
populate(SF);
forwardDeclare(SF);
auto SCCs = RefGraph.strongconnect();
@ -1593,62 +1593,47 @@ namespace bolt {
}
if (llvm::isa<TArrow>(A) && llvm::isa<TArrow>(B)) {
auto C1 = ArrowCursor(static_cast<TArrow*>(A), DidSwap ? RightPath : LeftPath);
auto C2 = ArrowCursor(static_cast<TArrow*>(B), DidSwap ? LeftPath : RightPath);
auto Arrow1 = static_cast<TArrow*>(A);
auto Arrow2 = static_cast<TArrow*>(B);
bool Success = true;
for (;;) {
auto T1 = C1.next();
auto T2 = C2.next();
if (T1 == nullptr && T2 == nullptr) {
break;
}
if (T1 == nullptr || T2 == nullptr) {
unifyError();
Success = false;
break;
}
if (!unify(T1, T2, DidSwap)) {
Success = false;
}
LeftPath.push_back(TypeIndex::forArrowParamType());
RightPath.push_back(TypeIndex::forArrowParamType());
if (!unify(Arrow1->ParamType, Arrow2->ParamType, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
LeftPath.push_back(TypeIndex::forArrowReturnType());
RightPath.push_back(TypeIndex::forArrowReturnType());
if (!unify(Arrow1->ReturnType, Arrow2->ReturnType, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
return Success;
/* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */
/* return false; */
/* } */
/* auto Count = Arr1->ParamTypes.size(); */
/* for (std::size_t I = 0; I < Count; I++) { */
/* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */
/* return false; */
/* } */
/* } */
/* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */
}
if (llvm::isa<TApp>(A) && llvm::isa<TApp>(B)) {
auto App1 = static_cast<TApp*>(A);
auto App2 = static_cast<TApp*>(B);
bool Success = true;
LeftPath.push_back(TypeIndex::forAppOpType());
RightPath.push_back(TypeIndex::forAppOpType());
if (!unify(App1->Op, App2->Op, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
LeftPath.push_back(TypeIndex::forAppArgType());
RightPath.push_back(TypeIndex::forAppArgType());
if (!unify(App1->Arg, App2->Arg, DidSwap)) {
Success = false;
}
LeftPath.pop_back();
RightPath.pop_back();
return Success;
}
if (llvm::isa<TArrow>(B)) {
swap();
}
if (llvm::isa<TArrow>(A)) {
auto Arr = static_cast<TArrow*>(A);
if (Arr->ParamTypes.empty()) {
auto Success = unify(Arr->ReturnType, B, DidSwap);
return Success;
}
}
if (llvm::isa<TTuple>(A) && llvm::isa<TTuple>(B)) {
auto Tuple1 = static_cast<TTuple*>(A);
auto Tuple2 = static_cast<TTuple*>(B);

View file

@ -165,14 +165,7 @@ namespace bolt {
{
auto Y = static_cast<const TArrow*>(Ty);
std::ostringstream Out;
Out << "(";
bool First = true;
for (auto PT: Y->ParamTypes) {
if (First) First = false;
else Out << ", ";
Out << describe(PT);
}
Out << ") -> " << describe(Y->ReturnType);
Out << describe(Y->ParamType) << " -> " << describe(Y->ReturnType);
return Out.str();
}
case TypeKind::Con:
@ -558,17 +551,10 @@ namespace bolt {
}
void visitArrowType(const TArrow* Ty) override {
W.write("(");
bool First = true;
std::size_t I = 0;
for (auto PT: Ty->ParamTypes) {
if (First) First = false;
else W.write(", ");
Path.push_back(TypeIndex::forArrowParamType(I++));
visit(PT);
Path.pop_back();
}
W.write(") -> ");
Path.push_back(TypeIndex::forArrowParamType());
visit(Ty->ParamType);
Path.pop_back();
W.write(" -> ");
Path.push_back(TypeIndex::forArrowReturnType());
visit(Ty->ReturnType);
Path.pop_back();

View file

@ -44,15 +44,11 @@ namespace bolt {
Kind = TypeIndexKind::AppArgType;
break;
case TypeIndexKind::ArrowParamType:
{
auto Arrow = llvm::cast<TArrow>(Ty);
if (I+1 < Arrow->ParamTypes.size()) {
++I;
} else {
Kind = TypeIndexKind::ArrowReturnType;
}
Kind = TypeIndexKind::ArrowReturnType;
break;
case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End;
break;
}
case TypeIndexKind::FieldType:
Kind = TypeIndexKind::FieldRestType;
break;
@ -60,9 +56,6 @@ namespace bolt {
case TypeIndexKind::TupleIndexType:
case TypeIndexKind::PresentType:
case TypeIndexKind::AppArgType:
case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End;
break;
case TypeIndexKind::TupleElement:
{
auto Tuple = llvm::cast<TTuple>(Ty);
@ -88,19 +81,15 @@ namespace bolt {
{
auto Arrow = static_cast<TArrow*>(Ty2);
bool Changed = false;
std::vector<Type*> NewParamTypes;
for (auto Ty: Arrow->ParamTypes) {
auto NewParamType = Ty->rewrite(Fn);
if (NewParamType != Ty) {
Changed = true;
}
NewParamTypes.push_back(NewParamType);
Type* NewParamType = Arrow->ParamType->rewrite(Fn);
if (NewParamType != Arrow->ParamType) {
Changed = true;
}
auto NewRetTy = Arrow->ReturnType->rewrite(Fn);
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamTypes, NewRetTy) : Ty2;
return Changed ? new TArrow(NewParamType, NewRetTy) : Ty2;
}
case TypeKind::Con:
return Ty2;
@ -173,9 +162,7 @@ namespace bolt {
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Arrow->ParamTypes) {
Ty->addTypeVars(TVs);
}
Arrow->ParamType->addTypeVars(TVs);
Arrow->ReturnType->addTypeVars(TVs);
break;
}
@ -229,12 +216,7 @@ namespace bolt {
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Arrow->ParamTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return Arrow->ReturnType->hasTypeVar(TV);
return Arrow->ParamType->hasTypeVar(TV) || Arrow->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
return false;
@ -308,7 +290,7 @@ namespace bolt {
case TypeIndexKind::TupleElement:
return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
case TypeIndexKind::ArrowParamType:
return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
return llvm::cast<TArrow>(this)->ParamType;
case TypeIndexKind::ArrowReturnType:
return llvm::cast<TArrow>(this)->ReturnType;
case TypeIndexKind::FieldType:
@ -445,13 +427,7 @@ namespace bolt {
TypeIndex Type::getStartIndex() {
switch (Kind) {
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
if (Arrow->ParamTypes.empty()) {
return TypeIndex::forArrowReturnType();
}
return TypeIndex::forArrowParamType(0);
}
return TypeIndex::forArrowParamType();
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);