diff --git a/src/checker.ts b/src/checker.ts index a1fdafb53..330e2921d 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -386,6 +386,7 @@ export class Checker { private boolType = new TCon(this.nextConTypeId++, [], 'Bool'); private contexts: InferContext[] = []; + private constraints: Constraint[] = []; private solution = new TVSub(); @@ -415,27 +416,9 @@ export class Checker { } private addConstraint(constraint: Constraint): void { - switch (constraint.kind) { - case ConstraintKind.Many: - { - for (const element of constraint.elements) { - this.addConstraint(element); - } - return; - } - case ConstraintKind.Equal: - { - const count = this.contexts.length; - let i; - for (i = count-1; i > 0; i--) { - const typeVars = this.contexts[i].typeVars; - if (typeVars.intersectsType(constraint.left) || typeVars.intersectsType(constraint.right)) { - break; - } - } - this.contexts[i].constraints.push(constraint); - break; - } + this.constraints.push(constraint); + if (this.contexts.length > 0) { + this.contexts[this.contexts.length-1].constraints.push(constraint); } } @@ -909,7 +892,7 @@ export class Checker { new CEqual( this.inferTypeExpression(node.typeAssert.typeExpression), type, - node.typeAssert + node ) ); } @@ -976,7 +959,7 @@ export class Checker { this.popContext(context); - this.solve(new CMany(constraints), this.solution); + this.solve(new CMany(this.constraints), this.solution); } private solve(constraint: Constraint, solution: TVSub): void { @@ -999,14 +982,8 @@ export class Checker { case ConstraintKind.Equal: { - if (!this.unify(constraint.left, constraint.right, solution)) { - this.diagnostics.add( - new UnificationFailedDiagnostic( - constraint.left.substitute(solution), - constraint.right.substitute(solution), - [...constraint.getNodes()], - ) - ); + if (!this.unify(constraint.left, constraint.right, solution, constraint)) { + // TODO break or continue? } break; } @@ -1016,7 +993,7 @@ export class Checker { } - private unify(left: Type, right: Type, solution: TVSub): boolean { + private unify(left: Type, right: Type, solution: TVSub, constraint: CEqual): boolean { if (left.kind === TypeKind.Var && solution.has(left)) { left = solution.get(left)!; @@ -1035,7 +1012,7 @@ export class Checker { } if (right.kind === TypeKind.Var) { - return this.unify(right, left, solution); + return this.unify(right, left, solution, constraint); } if (left.kind === TypeKind.Any || right.kind === TypeKind.Any) { @@ -1050,38 +1027,45 @@ export class Checker { let success = true; const count = left.paramTypes.length; for (let i = 0; i < count; i++) { - if (!this.unify(left.paramTypes[i], right.paramTypes[i], solution)) { + if (!this.unify(left.paramTypes[i], right.paramTypes[i], solution, constraint)) { success = false; } } - if (!this.unify(left.returnType, right.returnType, solution)) { + if (!this.unify(left.returnType, right.returnType, solution, constraint)) { success = false; } return success; } if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) { - return this.unify(left.returnType, right, solution); + return this.unify(left.returnType, right, solution, constraint); } if (right.kind === TypeKind.Arrow) { - return this.unify(right, left, solution); + return this.unify(right, left, solution, constraint); } if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) { - if (left.id !== right.id) { - return false; - } - assert(left.argTypes.length === right.argTypes.length); - const count = left.argTypes.length; - for (let i = 0; i < count; i++) { - if (!this.unify(left.argTypes[i], right.argTypes[i], solution)) { - return false; + if (left.id === right.id) { + assert(left.argTypes.length === right.argTypes.length); + const count = left.argTypes.length; + let success = true; + for (let i = 0; i < count; i++) { + if (!this.unify(left.argTypes[i], right.argTypes[i], solution, constraint)) { + success = false; + } } + return success; } - return true; } + this.diagnostics.add( + new UnificationFailedDiagnostic( + left.substitute(solution), + right.substitute(solution), + [...constraint.getNodes()], + ) + ); return false; }