Fix bug in inferencer and rename some variables

This commit is contained in:
Sam Vervaeck 2023-05-08 19:57:24 +02:00
parent 10f0ebae20
commit 936afd3be0
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY

View file

@ -18,25 +18,25 @@ namespace bolt {
break; break;
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Y = static_cast<TArrow*>(this); auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Y->ParamTypes) { for (auto Ty: Arrow->ParamTypes) {
Ty->addTypeVars(TVs); Ty->addTypeVars(TVs);
} }
Y->ReturnType->addTypeVars(TVs); Arrow->ReturnType->addTypeVars(TVs);
break; break;
} }
case TypeKind::Con: case TypeKind::Con:
{ {
auto Y = static_cast<TCon*>(this); auto Con = static_cast<TCon*>(this);
for (auto Ty: Y->Args) { for (auto Ty: Con->Args) {
Ty->addTypeVars(TVs); Ty->addTypeVars(TVs);
} }
break; break;
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Y = static_cast<TTuple*>(this); auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Y->ElementTypes) { for (auto Ty: Tuple->ElementTypes) {
Ty->addTypeVars(TVs); Ty->addTypeVars(TVs);
} }
break; break;
@ -52,18 +52,18 @@ namespace bolt {
return static_cast<TVar*>(this)->Id == TV->Id; return static_cast<TVar*>(this)->Id == TV->Id;
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Y = static_cast<TArrow*>(this); auto Arrow = static_cast<TArrow*>(this);
for (auto Ty: Y->ParamTypes) { for (auto Ty: Arrow->ParamTypes) {
if (Ty->hasTypeVar(TV)) { if (Ty->hasTypeVar(TV)) {
return true; return true;
} }
} }
return Y->ReturnType->hasTypeVar(TV); return Arrow->ReturnType->hasTypeVar(TV);
} }
case TypeKind::Con: case TypeKind::Con:
{ {
auto Y = static_cast<TCon*>(this); auto Con = static_cast<TCon*>(this);
for (auto Ty: Y->Args) { for (auto Ty: Con->Args) {
if (Ty->hasTypeVar(TV)) { if (Ty->hasTypeVar(TV)) {
return true; return true;
} }
@ -72,8 +72,8 @@ namespace bolt {
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Y = static_cast<TTuple*>(this); auto Tuple = static_cast<TTuple*>(this);
for (auto Ty: Y->ElementTypes) { for (auto Ty: Tuple->ElementTypes) {
if (Ty->hasTypeVar(TV)) { if (Ty->hasTypeVar(TV)) {
return true; return true;
} }
@ -89,24 +89,24 @@ namespace bolt {
switch (Kind) { switch (Kind) {
case TypeKind::Var: case TypeKind::Var:
{ {
auto Y = static_cast<TVar*>(this); auto TV = static_cast<TVar*>(this);
auto Match = Sub.find(Y); auto Match = Sub.find(TV);
return Match != Sub.end() ? Match->second->substitute(Sub) : Y; return Match != Sub.end() ? Match->second->substitute(Sub) : this;
} }
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Y = static_cast<TArrow*>(this); auto Arrow = static_cast<TArrow*>(this);
bool Changed = false; bool Changed = false;
std::vector<Type*> NewParamTypes; std::vector<Type*> NewParamTypes;
for (auto Ty: Y->ParamTypes) { for (auto Ty: Arrow->ParamTypes) {
auto NewParamType = Ty->substitute(Sub); auto NewParamType = Ty->substitute(Sub);
if (NewParamType != Ty) { if (NewParamType != Ty) {
Changed = true; Changed = true;
} }
NewParamTypes.push_back(NewParamType); NewParamTypes.push_back(NewParamType);
} }
auto NewRetTy = Y->ReturnType->substitute(Sub) ; auto NewRetTy = Arrow->ReturnType->substitute(Sub) ;
if (NewRetTy != Y->ReturnType) { if (NewRetTy != Arrow->ReturnType) {
Changed = true; Changed = true;
} }
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this; return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
@ -115,24 +115,24 @@ namespace bolt {
return this; return this;
case TypeKind::Con: case TypeKind::Con:
{ {
auto Y = static_cast<TCon*>(this); auto Con = static_cast<TCon*>(this);
bool Changed = false; bool Changed = false;
std::vector<Type*> NewArgs; std::vector<Type*> NewArgs;
for (auto Arg: Y->Args) { for (auto Arg: Con->Args) {
auto NewArg = Arg->substitute(Sub); auto NewArg = Arg->substitute(Sub);
if (NewArg != Arg) { if (NewArg != Arg) {
Changed = true; Changed = true;
} }
NewArgs.push_back(NewArg); NewArgs.push_back(NewArg);
} }
return Changed ? new TCon(Y->Id, NewArgs, Y->DisplayName) : this; return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
auto Y = static_cast<TTuple*>(this); auto Tuple = static_cast<TTuple*>(this);
bool Changed = false; bool Changed = false;
std::vector<Type*> NewElementTypes; std::vector<Type*> NewElementTypes;
for (auto Ty: Y->ElementTypes) { for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = Ty->substitute(Sub); auto NewElementType = Ty->substitute(Sub);
if (NewElementType != Ty) { if (NewElementType != Ty) {
Changed = true; Changed = true;
@ -148,14 +148,14 @@ namespace bolt {
switch (Kind) { switch (Kind) {
case ConstraintKind::Equal: case ConstraintKind::Equal:
{ {
auto Y = static_cast<CEqual*>(this); auto Equal = static_cast<CEqual*>(this);
return new CEqual(Y->Left->substitute(Sub), Y->Right->substitute(Sub), Y->Source); return new CEqual(Equal->Left->substitute(Sub), Equal->Right->substitute(Sub), Equal->Source);
} }
case ConstraintKind::Many: case ConstraintKind::Many:
{ {
auto Y = static_cast<CMany*>(this); auto Many = static_cast<CMany*>(this);
auto NewConstraints = new ConstraintSet(); auto NewConstraints = new ConstraintSet();
for (auto Element: Y->Elements) { for (auto Element: Many->Elements) {
NewConstraints->push_back(Element->substitute(Sub)); NewConstraints->push_back(Element->substitute(Sub));
} }
return new CMany(*NewConstraints); return new CMany(*NewConstraints);
@ -212,19 +212,19 @@ namespace bolt {
return false; return false;
} }
void Checker::addConstraint(Constraint* Constraint) { void Checker::addConstraint(Constraint* C) {
switch (Constraint->getKind()) { switch (C->getKind()) {
case ConstraintKind::Equal: case ConstraintKind::Equal:
{ {
auto Y = static_cast<CEqual*>(Constraint); auto Y = static_cast<CEqual*>(C);
for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) { for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) {
auto& Ctx = **Iter; auto& Ctx = **Iter;
if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) { if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) {
Ctx.Constraints.push_back(Constraint); Ctx.Constraints.push_back(C);
return; return;
} }
} }
Contexts.front()->Constraints.push_back(Constraint); Contexts.front()->Constraints.push_back(C);
//auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth); //auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth);
//ZEN_ASSERT(I < Contexts.size()); //ZEN_ASSERT(I < Contexts.size());
//auto Ctx = Contexts[I]; //auto Ctx = Contexts[I];
@ -233,7 +233,7 @@ namespace bolt {
} }
case ConstraintKind::Many: case ConstraintKind::Many:
{ {
auto Y = static_cast<CMany*>(Constraint); auto Y = static_cast<CMany*>(C);
for (auto Element: Y->Elements) { for (auto Element: Y->Elements) {
addConstraint(Element); addConstraint(Element);
} }
@ -255,8 +255,8 @@ namespace bolt {
case NodeType::SourceFile: case NodeType::SourceFile:
{ {
auto Y = static_cast<SourceFile*>(X); auto File = static_cast<SourceFile*>(X);
for (auto Element: Y->Elements) { for (auto Element: File->Elements) {
forwardDeclare(Element) ; forwardDeclare(Element) ;
} }
break; break;
@ -264,30 +264,30 @@ namespace bolt {
case NodeType::LetDeclaration: case NodeType::LetDeclaration:
{ {
auto Y = static_cast<LetDeclaration*>(X); auto Let = static_cast<LetDeclaration*>(X);
auto NewCtx = new InferContext(); auto NewCtx = new InferContext();
Y->Ctx = NewCtx; Let->Ctx = NewCtx;
Contexts.push_back(NewCtx); Contexts.push_back(NewCtx);
Type* Ty; Type* Ty;
if (Y->TypeAssert) { if (Let->TypeAssert) {
Ty = inferTypeExpression(Y->TypeAssert->TypeExpression); Ty = inferTypeExpression(Let->TypeAssert->TypeExpression);
} else { } else {
Ty = createTypeVar(); Ty = createTypeVar();
} }
Y->Ty = Ty; Let->Ty = Ty;
if (Y->Body) { if (Let->Body) {
switch (Y->Body->Type) { switch (Let->Body->Type) {
case NodeType::LetExprBody: case NodeType::LetExprBody:
break; break;
case NodeType::LetBlockBody: case NodeType::LetBlockBody:
{ {
auto Z = static_cast<LetBlockBody*>(Y->Body); auto Block = static_cast<LetBlockBody*>(Let->Body);
NewCtx->ReturnType = createTypeVar(); NewCtx->ReturnType = createTypeVar();
for (auto Element: Z->Elements) { for (auto Element: Block->Elements) {
forwardDeclare(Element); forwardDeclare(Element);
} }
break; break;
@ -299,7 +299,7 @@ namespace bolt {
Contexts.pop_back(); Contexts.pop_back();
inferBindings(Y->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs); inferBindings(Let->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs);
break; break;
@ -318,8 +318,8 @@ namespace bolt {
case NodeType::SourceFile: case NodeType::SourceFile:
{ {
auto Y = static_cast<SourceFile*>(X); auto File = static_cast<SourceFile*>(X);
for (auto Element: Y->Elements) { for (auto Element: File->Elements) {
infer(Element); infer(Element);
} }
break; break;
@ -327,8 +327,8 @@ namespace bolt {
case NodeType::IfStatement: case NodeType::IfStatement:
{ {
auto Y = static_cast<IfStatement*>(X); auto IfStmt = static_cast<IfStatement*>(X);
for (auto Part: Y->Parts) { for (auto Part: IfStmt->Parts) {
if (Part->Test != nullptr) { if (Part->Test != nullptr) {
addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test }); addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test });
} }
@ -341,15 +341,15 @@ namespace bolt {
case NodeType::LetDeclaration: case NodeType::LetDeclaration:
{ {
auto Y = static_cast<LetDeclaration*>(X); auto LetDecl = static_cast<LetDeclaration*>(X);
auto NewCtx = Y->Ctx; auto NewCtx = LetDecl->Ctx;
Contexts.push_back(NewCtx); Contexts.push_back(NewCtx);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
Type* RetType; Type* RetType;
for (auto Param: Y->Params) { for (auto Param: LetDecl->Params) {
// TODO incorporate Param->TypeAssert or make it a kind of pattern // TODO incorporate Param->TypeAssert or make it a kind of pattern
TVar* TV = createTypeVar(); TVar* TV = createTypeVar();
TVSet NoTVs; TVSet NoTVs;
@ -358,19 +358,19 @@ namespace bolt {
ParamTypes.push_back(TV); ParamTypes.push_back(TV);
} }
if (Y->Body) { if (LetDecl->Body) {
switch (Y->Body->Type) { switch (LetDecl->Body->Type) {
case NodeType::LetExprBody: case NodeType::LetExprBody:
{ {
auto Z = static_cast<LetExprBody*>(Y->Body); auto Expr = static_cast<LetExprBody*>(LetDecl->Body);
RetType = inferExpression(Z->Expression); RetType = inferExpression(Expr->Expression);
break; break;
} }
case NodeType::LetBlockBody: case NodeType::LetBlockBody:
{ {
auto Z = static_cast<LetBlockBody*>(Y->Body); auto Block = static_cast<LetBlockBody*>(LetDecl->Body);
RetType = Y->Ty; RetType = createTypeVar();
for (auto Element: Z->Elements) { for (auto Element: Block->Elements) {
infer(Element); infer(Element);
} }
break; break;
@ -382,7 +382,7 @@ namespace bolt {
RetType = createTypeVar(); RetType = createTypeVar();
} }
addConstraint(new CEqual { Y->Ty, new TArrow(ParamTypes, RetType), X }); addConstraint(new CEqual { LetDecl->Ty, new TArrow(ParamTypes, RetType), X });
Contexts.pop_back(); Contexts.pop_back();
@ -391,10 +391,10 @@ namespace bolt {
case NodeType::ReturnStatement: case NodeType::ReturnStatement:
{ {
auto Y = static_cast<ReturnStatement*>(X); auto RetStmt = static_cast<ReturnStatement*>(X);
Type* ReturnType; Type* ReturnType;
if (Y->Expression) { if (RetStmt->Expression) {
ReturnType = inferExpression(Y->Expression); ReturnType = inferExpression(RetStmt->Expression);
} else { } else {
ReturnType = new TTuple({}); ReturnType = new TTuple({});
} }
@ -404,8 +404,8 @@ namespace bolt {
case NodeType::ExpressionStatement: case NodeType::ExpressionStatement:
{ {
auto Y = static_cast<ExpressionStatement*>(X); auto ExprStmt = static_cast<ExpressionStatement*>(X);
inferExpression(Y->Expression); inferExpression(ExprStmt->Expression);
break; break;
} }
@ -464,10 +464,10 @@ namespace bolt {
case NodeType::ReferenceTypeExpression: case NodeType::ReferenceTypeExpression:
{ {
auto Y = static_cast<ReferenceTypeExpression*>(X); auto RefTE = static_cast<ReferenceTypeExpression*>(X);
auto Ty = lookupMono(Y->Name->Name->Text); auto Ty = lookupMono(RefTE->Name->Name->Text);
if (Ty == nullptr) { if (Ty == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name->Name); DE.add<BindingNotFoundDiagnostic>(RefTE->Name->Name->Text, RefTE->Name->Name);
return new TAny(); return new TAny();
} }
Mapping[X] = Ty; Mapping[X] = Ty;
@ -476,12 +476,12 @@ namespace bolt {
case NodeType::ArrowTypeExpression: case NodeType::ArrowTypeExpression:
{ {
auto Y = static_cast<ArrowTypeExpression*>(X); auto ArrowTE = static_cast<ArrowTypeExpression*>(X);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto ParamType: Y->ParamTypes) { for (auto ParamType: ArrowTE->ParamTypes) {
ParamTypes.push_back(inferTypeExpression(ParamType)); ParamTypes.push_back(inferTypeExpression(ParamType));
} }
auto ReturnType = inferTypeExpression(Y->ReturnType); auto ReturnType = inferTypeExpression(ArrowTE->ReturnType);
auto Ty = new TArrow(ParamTypes, ReturnType); auto Ty = new TArrow(ParamTypes, ReturnType);
Mapping[X] = Ty; Mapping[X] = Ty;
return Ty; return Ty;
@ -499,9 +499,9 @@ namespace bolt {
case NodeType::ConstantExpression: case NodeType::ConstantExpression:
{ {
auto Y = static_cast<ConstantExpression*>(X); auto Const = static_cast<ConstantExpression*>(X);
Type* Ty = nullptr; Type* Ty = nullptr;
switch (Y->Token->Type) { switch (Const->Token->Type) {
case NodeType::IntegerLiteral: case NodeType::IntegerLiteral:
Ty = lookupMono("Int"); Ty = lookupMono("Int");
break; break;
@ -518,17 +518,17 @@ namespace bolt {
case NodeType::ReferenceExpression: case NodeType::ReferenceExpression:
{ {
auto Y = static_cast<ReferenceExpression*>(X); auto Ref = static_cast<ReferenceExpression*>(X);
ZEN_ASSERT(Y->Name->ModulePath.empty()); ZEN_ASSERT(Ref->Name->ModulePath.empty());
auto Ctx = lookupCall(Y, Y->Name->getSymbolPath()); auto Ctx = lookupCall(Ref, Ref->Name->getSymbolPath());
if (Ctx) { if (Ctx) {
/* std::cerr << "recursive call!\n"; */ /* std::cerr << "recursive call!\n"; */
ZEN_ASSERT(Ctx->ReturnType != nullptr); ZEN_ASSERT(Ctx->ReturnType != nullptr);
return Ctx->ReturnType; return Ctx->ReturnType;
} }
auto Scm = lookup(Y->Name->Name->Text); auto Scm = lookup(Ref->Name->Name->Text);
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Y->Name->Name->Text, Y->Name); DE.add<BindingNotFoundDiagnostic>(Ref->Name->Name->Text, Ref->Name);
return new TAny(); return new TAny();
} }
auto Ty = instantiate(*Scm, X); auto Ty = instantiate(*Scm, X);
@ -538,11 +538,11 @@ namespace bolt {
case NodeType::CallExpression: case NodeType::CallExpression:
{ {
auto Y = static_cast<CallExpression*>(X); auto Call = static_cast<CallExpression*>(X);
auto OpTy = inferExpression(Y->Function); auto OpTy = inferExpression(Call->Function);
auto RetType = createTypeVar(); auto RetType = createTypeVar();
std::vector<Type*> ArgTypes; std::vector<Type*> ArgTypes;
for (auto Arg: Y->Args) { for (auto Arg: Call->Args) {
ArgTypes.push_back(inferExpression(Arg)); ArgTypes.push_back(inferExpression(Arg));
} }
addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X });
@ -552,17 +552,17 @@ namespace bolt {
case NodeType::InfixExpression: case NodeType::InfixExpression:
{ {
auto Y = static_cast<InfixExpression*>(X); auto Infix = static_cast<InfixExpression*>(X);
auto Scm = lookup(Y->Operator->getText()); auto Scm = lookup(Infix->Operator->getText());
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator); DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
return new TAny(); return new TAny();
} }
auto OpTy = instantiate(*Scm, Y->Operator); auto OpTy = instantiate(*Scm, Infix->Operator);
auto RetTy = createTypeVar(); auto RetTy = createTypeVar();
std::vector<Type*> ArgTys; std::vector<Type*> ArgTys;
ArgTys.push_back(inferExpression(Y->LHS)); ArgTys.push_back(inferExpression(Infix->LHS));
ArgTys.push_back(inferExpression(Y->RHS)); ArgTys.push_back(inferExpression(Infix->RHS));
addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X });
Mapping[X] = RetTy; Mapping[X] = RetTy;
return RetTy; return RetTy;
@ -570,8 +570,8 @@ namespace bolt {
case NodeType::NestedExpression: case NodeType::NestedExpression:
{ {
auto Y = static_cast<NestedExpression*>(X); auto Nested = static_cast<NestedExpression*>(X);
return inferExpression(Y->Inner); return inferExpression(Nested->Inner);
} }
default: default:
@ -636,8 +636,8 @@ namespace bolt {
case ConstraintKind::Many: case ConstraintKind::Many:
{ {
auto Y = static_cast<CMany*>(Constraint); auto Many = static_cast<CMany*>(Constraint);
for (auto Constraint: Y->Elements) { for (auto Constraint: Many->Elements) {
Queue.push(Constraint); Queue.push(Constraint);
} }
break; break;
@ -645,10 +645,10 @@ namespace bolt {
case ConstraintKind::Equal: case ConstraintKind::Equal:
{ {
auto Y = static_cast<CEqual*>(Constraint); auto Equal = static_cast<CEqual*>(Constraint);
/* std::cerr << describe(Y->Left) << " ~ " << describe(Y->Right) << std::endl; */ std::cerr << describe(Equal->Left) << " ~ " << describe(Equal->Right) << std::endl;
if (!unify(Y->Left, Y->Right, Solution)) { if (!unify(Equal->Left, Equal->Right, Solution)) {
DE.add<UnificationErrorDiagnostic>(Y->Left->substitute(Solution), Y->Right->substitute(Solution), Y->Source); DE.add<UnificationErrorDiagnostic>(Equal->Left->substitute(Solution), Equal->Right->substitute(Solution), Equal->Source);
} }
break; break;
} }
@ -661,27 +661,29 @@ namespace bolt {
bool Checker::unify(Type* A, Type* B, TVSub& Solution) { bool Checker::unify(Type* A, Type* B, TVSub& Solution) {
if (A->getKind() == TypeKind::Var) { while (A->getKind() == TypeKind::Var) {
auto Match = Solution.find(static_cast<TVar*>(A)); auto Match = Solution.find(static_cast<TVar*>(A));
if (Match != Solution.end()) { if (Match == Solution.end()) {
A = Match->second; break;
} }
A = Match->second;
} }
if (B->getKind() == TypeKind::Var) { while (B->getKind() == TypeKind::Var) {
auto Match = Solution.find(static_cast<TVar*>(B)); auto Match = Solution.find(static_cast<TVar*>(B));
if (Match != Solution.end()) { if (Match == Solution.end()) {
B = Match->second; break;
} }
B = Match->second;
} }
if (A->getKind() == TypeKind::Var) { if (A->getKind() == TypeKind::Var) {
auto Y = static_cast<TVar*>(A); auto TV = static_cast<TVar*>(A);
if (B->hasTypeVar(Y)) { if (B->hasTypeVar(TV)) {
// TODO occurs check // TODO occurs check
return false; return false;
} }
Solution[Y] = B; Solution[TV] = B;
return true; return true;
} }
@ -694,24 +696,24 @@ namespace bolt {
} }
if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) { if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) {
auto Y = static_cast<TArrow*>(A); auto Arr1 = static_cast<TArrow*>(A);
auto Z = static_cast<TArrow*>(B); auto Arr2 = static_cast<TArrow*>(B);
if (Y->ParamTypes.size() != Z->ParamTypes.size()) { if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) {
return false; return false;
} }
auto Count = Y->ParamTypes.size(); auto Count = Arr1->ParamTypes.size();
for (std::size_t I = 0; I < Count; I++) { for (std::size_t I = 0; I < Count; I++) {
if (!unify(Y->ParamTypes[I], Z->ParamTypes[I], Solution)) { if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) {
return false; return false;
} }
} }
return unify(Y->ReturnType, Z->ReturnType, Solution); return unify(Arr1->ReturnType, Arr2->ReturnType, Solution);
} }
if (A->getKind() == TypeKind::Arrow) { if (A->getKind() == TypeKind::Arrow) {
auto Y = static_cast<TArrow*>(A); auto Arr = static_cast<TArrow*>(A);
if (Y->ParamTypes.empty()) { if (Arr->ParamTypes.empty()) {
return unify(Y->ReturnType, B, Solution); return unify(Arr->ReturnType, B, Solution);
} }
} }
@ -720,15 +722,15 @@ namespace bolt {
} }
if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) { if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) {
auto Y = static_cast<TTuple*>(A); auto Tuple1 = static_cast<TTuple*>(A);
auto Z = static_cast<TTuple*>(B); auto Tuple2 = static_cast<TTuple*>(B);
if (Y->ElementTypes.size() != Z->ElementTypes.size()) { if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) {
return false; return false;
} }
auto Count = Y->ElementTypes.size(); auto Count = Tuple1->ElementTypes.size();
bool Success = true; bool Success = true;
for (size_t I = 0; I < Count; I++) { for (size_t I = 0; I < Count; I++) {
if (!unify(Y->ElementTypes[I], Z->ElementTypes[I], Solution)) { if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], Solution)) {
Success = false; Success = false;
} }
} }
@ -736,15 +738,15 @@ namespace bolt {
} }
if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) { if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) {
auto Y = static_cast<TCon*>(A); auto Con1 = static_cast<TCon*>(A);
auto Z = static_cast<TCon*>(B); auto Con2 = static_cast<TCon*>(B);
if (Y->Id != Z->Id) { if (Con1->Id != Con2->Id) {
return false; return false;
} }
ZEN_ASSERT(Y->Args.size() == Z->Args.size()); ZEN_ASSERT(Con1->Args.size() == Con2->Args.size());
auto Count = Y->Args.size(); auto Count = Con1->Args.size();
for (std::size_t I = 0; I < Count; I++) { for (std::size_t I = 0; I < Count; I++) {
if (!unify(Y->Args[I], Z->Args[I], Solution)) { if (!unify(Con1->Args[I], Con2->Args[I], Solution)) {
return false; return false;
} }
} }