From 91a4872c34bb5deb0ac101db13e32bee487bfd7a Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Thu, 22 Jun 2023 16:19:51 +0200 Subject: [PATCH] Merge solver back with checker and apply algorithm for eager constraint solving --- compiler/src/checker.ts | 386 ++++++++++++++++++++++++++++++++++++---- compiler/src/solver.ts | 312 -------------------------------- 2 files changed, 349 insertions(+), 349 deletions(-) delete mode 100644 compiler/src/solver.ts diff --git a/compiler/src/checker.ts b/compiler/src/checker.ts index 0f4a9f576..1e9d0c084 100644 --- a/compiler/src/checker.ts +++ b/compiler/src/checker.ts @@ -29,12 +29,13 @@ import { ModuleNotFoundDiagnostic, TypeclassNotFoundDiagnostic, TypeclassDeclaredTwiceDiagnostic, + FieldNotFoundDiagnostic, + TypeMismatchDiagnostic, } from "./diagnostics"; import { assert, assertNever, isEmpty, MultiMap, toStringTag, InspectFn, implementationLimitation } from "./util"; import { Analyser } from "./analysis"; import { InspectOptions } from "util"; -import { ConstraintSolver } from "./solver"; -import { TypeKind, TApp, TArrow, TCon, TField, TNil, TNominal, TPresent, TTuple, TVar, TVSet, TVSub, Type } from "./types"; +import { TypeKind, TApp, TArrow, TCon, TField, TNil, TNominal, TPresent, TTuple, TVar, TVSet, TVSub, Type, TypeBase, TAbsent } from "./types"; import { CClass, CEmpty, CEqual, CMany, Constraint, ConstraintKind, ConstraintSet } from "./constraints"; // export class Qual { @@ -363,9 +364,25 @@ function isFunctionDeclarationLike(node: LetDeclaration): boolean { // && (node.params.length > 0 || (node.body !== null && node.body.kind === SyntaxKind.BlockBody)); } +function* integers(start: number = 0) { + let i = start; + for (;;) { + yield i++; + } +} + +function hasTypeVar(typeVars: TVSet, type: Type): boolean { + for (const tv of type.getTypeVars()) { + if (typeVars.has(tv)) { + return true; + } + } + return false; +} + export class Checker { - private nextTypeVarId = 0; + private typeVarIds = integers(); private nextKindVarId = 0; private nextConTypeId = 0; @@ -391,8 +408,8 @@ export class Checker { this.globalKindEnv.set('String', new KType()); this.globalKindEnv.set('Bool', new KType()); - const a = new TVar(this.nextTypeVarId++); - const b = new TVar(this.nextTypeVarId++); + const a = new TVar(this.typeVarIds.next().value!); + const b = new TVar(this.typeVarIds.next().value!); this.globalTypeEnv.add('$', Forall.fromArrays([ a, b ], [], new TArrow(new TArrow(new TArrow(a, b), a), b)), Symkind.Var); this.globalTypeEnv.add('String', Forall.fromArrays([], [], this.stringType), Symkind.Type); @@ -426,7 +443,7 @@ export class Checker { } private createTypeVar(node: Syntax | null = null): TVar { - const typeVar = new TVar(this.nextTypeVarId++, node); + const typeVar = new TVar(this.typeVarIds.next().value!, node); this.getContext().typeVars.add(typeVar); return typeVar; } @@ -436,7 +453,50 @@ export class Checker { } private addConstraint(constraint: Constraint): void { - this.getContext().constraints.push(constraint); + switch (constraint.kind) { + case ConstraintKind.Empty: + break; + case ConstraintKind.Many: + for (const element of constraint.elements) { + this.addConstraint(element); + } + break; + case ConstraintKind.Equal: + { + const global = 0; + let maxLevelLeft = global; + for (let i = this.contexts.length; i-- > 0;) { + const ctx = this.contexts[i]; + if (hasTypeVar(ctx.typeVars, constraint.left)) { + maxLevelLeft = i; + break; + } + } + let maxLevelRight = global; + for (let i = this.contexts.length; i-- > 0;) { + const ctx = this.contexts[i]; + if (hasTypeVar(ctx.typeVars, constraint.left)) { + maxLevelRight = i; + break; + } + } + const upperLevel = Math.max(maxLevelLeft, maxLevelRight); + let lowerLevel = upperLevel; + for (let i = 0; i < this.contexts.length; i++) { + const ctx = this.contexts[i]; + if (hasTypeVar(ctx.typeVars, constraint.left) || hasTypeVar(ctx.typeVars, constraint.right)) { + lowerLevel = i; + break; + } + } + if (upperLevel == lowerLevel || maxLevelLeft == global || maxLevelRight == global) { + this.solve(constraint); + } else { + this.contexts[upperLevel].constraints.push(constraint); + } + break; + } + } } private pushContext(context: InferContext) { @@ -1757,21 +1817,18 @@ export class Checker { this.forwardDeclareKind(sourceFile, kenv); this.inferKind(sourceFile, kenv); + this.initialize(sourceFile, this.globalTypeEnv); + const typeVars = new TVSet(); const constraints = new ConstraintSet(); - const env = new TypeEnv(this.globalTypeEnv); - const context: InferContext = { typeVars, constraints, env, returnType: null }; - - this.pushContext(context); - - this.initialize(sourceFile, env); - - this.pushContext({ + const sourceFileCtx = { typeVars, constraints, env: sourceFile.typeEnv!, returnType: null - }); + }; + + this.pushContext(sourceFileCtx); const sccs = [...this.analyser.getSortedDeclarations()]; @@ -1831,7 +1888,7 @@ export class Checker { // this.addConstraint(new CEqual(type, other.inferredType!, node)); // } - this.contexts.pop(); + this.popContext(innerCtx); if (node.parent!.kind !== SyntaxKind.InstanceDeclaration) { const scopeDecl = node.parent!.getScope().node; @@ -1851,14 +1908,284 @@ export class Checker { this.infer(sourceFile); - this.contexts.pop(); - this.popContext(context); + this.popContext(sourceFileCtx); - const solver = new ConstraintSolver(this.diagnostics, this.nextTypeVarId); + this.solve(new CMany(constraints)); - solver.solve(new CMany(constraints)); + } - this.solution = solver.solution; + private path: string[] = []; + private constraint: Constraint | null = null; + private maxTypeErrorCount = 5; + + private find(type: Type): Type { + while (type.kind === TypeKind.Var && this.solution.has(type)) { + type = this.solution.get(type)!; + } + return type; + } + + private unifyField(left: Type, right: Type, enableDiagnostics: boolean): boolean { + + const swap = () => { [right, left] = [left, right]; } + + if (left.kind === TypeKind.Absent && right.kind === TypeKind.Absent) { + return true; + } + + if (right.kind === TypeKind.Absent) { + swap(); + } + + if (left.kind === TypeKind.Absent) { + assert(right.kind === TypeKind.Present); + const fieldName = this.path[this.path.length-1]; + if (enableDiagnostics) { + this.diagnostics.add( + new FieldNotFoundDiagnostic(fieldName, left.node, right.type.node, this.constraint!.firstNode) + ); + } + return false; + } + + assert(left.kind === TypeKind.Present && right.kind === TypeKind.Present); + return this.unify(left.type, right.type, enableDiagnostics); + } + + + private unify(left: Type, right: Type, enableDiagnostics: boolean): boolean { + + left = this.find(left); + right = this.find(right); + + // console.log(`unify ${describeType(left)} @ ${left.node && left.node.constructor && left.node.constructor.name} ~ ${describeType(right)} @ ${right.node && right.node.constructor && right.node.constructor.name}`); + + const swap = () => { [right, left] = [left, right]; } + + if (left.kind !== TypeKind.Var && right.kind === TypeKind.Var) { + swap(); + } + + if (left.kind === TypeKind.Var) { + + // Perform an occurs check, verifying whether left occurs + // somewhere inside the structure of right. If so, unification + // makes no sense. + if (right.hasTypeVar(left)) { + // TODO print a diagnostic + return false; + } + + // We are ready to join the types, so the first thing we do is + // propagating the type classes that 'left' requires to 'right'. + // If 'right' is another type variable, we're lucky. We just copy + // the missing type classes from 'left' to 'right'. Otherwise, + //const propagateClasses = (classes: Iterable, type: Type) => { + // if (type.kind === TypeKind.Var) { + // for (const constraint of classes) { + // type.context.add(constraint); + // } + // } else if (type.kind === TypeKind.Con) { + // for (const constraint of classes) { + // propagateClassTCon(constraint, type); + // } + // } else { + // //assert(false); + // //this.diagnostics.add(new ); + // } + //} + + //const propagateClassTCon = (clazz: ClassDeclaration, type: TCon) => { + // const s = this.findInstanceContext(type, clazz); + // let i = 0; + // for (const classes of s) { + // propagateClasses(classes, type.argTypes[i++]); + // } + //} + + //propagateClasses(left.context, right); + + // We are all clear; set the actual type of left to right. + this.solution.set(left, right); + + // This is a very specific adjustment that is critical to the + // well-functioning of the infer/unify algorithm. When addConstraint() is + // called, it may decide to solve the constraint immediately during + // inference. If this happens, a type variable might get assigned a concrete + // type such as Int. We therefore never want the variable to be polymorphic + // and be instantiated with a fresh variable, as that would allow Bool to + // collide with Int. + // + // Should it get assigned another unification variable, that's OK too + // because then that variable is what matters and it will become the new + // (possibly polymorphic) variable. + if (this.contexts.length > 0) { + this.contexts[this.contexts.length-1].typeVars.delete(left); + } + + // These types will be join, and we'd like to track that + // into a special chain. + TypeBase.join(left, right); + + // if (left.node !== null) { + // right.node = left.node; + // } + + return true; + } + + if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { + let success = true; + if (!this.unify(left.paramType, right.paramType, enableDiagnostics)) { + success = false; + } + if (!this.unify(left.returnType, right.returnType, enableDiagnostics)) { + success = false; + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + + if (left.kind === TypeKind.Tuple && right.kind === TypeKind.Tuple) { + if (left.elementTypes.length === right.elementTypes.length) { + let success = false; + const count = left.elementTypes.length; + for (let i = 0; i < count; i++) { + if (!this.unify(left.elementTypes[i], right.elementTypes[i], enableDiagnostics)) { + success = false; + } + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + } + + if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) { + if (left.id === right.id) { + assert(left.argTypes.length === right.argTypes.length); + const count = left.argTypes.length; + let success = true; + for (let i = 0; i < count; i++) { + if (!this.unify(left.argTypes[i], right.argTypes[i], enableDiagnostics)) { + success = false; + } + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + } + + if (left.kind === TypeKind.Nil && right.kind === TypeKind.Nil) { + return true; + } + + if (left.kind === TypeKind.Field && right.kind === TypeKind.Field) { + if (left.name === right.name) { + let success = true; + this.path.push(left.name); + if (!this.unifyField(left.type, right.type, enableDiagnostics)) { + success = false; + } + this.path.pop(); + if (!this.unify(left.restType, right.restType, enableDiagnostics)) { + success = false; + } + return success; + } + let success = true; + const newRestType = new TVar(this.typeVarIds.next().value!); + if (!this.unify(left.restType, new TField(right.name, right.type, newRestType), enableDiagnostics)) { + success = false; + } + if (!this.unify(right.restType, new TField(left.name, left.type, newRestType), enableDiagnostics)) { + success = false; + } + return success; + } + + if (left.kind === TypeKind.Nil && right.kind === TypeKind.Field) { + swap(); + } + + if (left.kind === TypeKind.Field && right.kind === TypeKind.Nil) { + let success = true; + this.path.push(left.name); + if (!this.unifyField(left.type, new TAbsent(right.node), enableDiagnostics)) { + success = false; + } + this.path.pop(); + if (!this.unify(left.restType, right, enableDiagnostics)) { + success = false; + } + return success + } + + if (left.kind === TypeKind.Nominal && right.kind === TypeKind.Nominal) { + if (left.decl === right.decl) { + return true; + } + // fall through to error reporting + } + + if (left.kind === TypeKind.App && right.kind === TypeKind.App) { + return this.unify(left.left, right.left, enableDiagnostics) + && this.unify(left.right, right.right, enableDiagnostics); + } + + if (enableDiagnostics) { + this.diagnostics.add( + new TypeMismatchDiagnostic( + left.substitute(this.solution), + right.substitute(this.solution), + [...this.constraint!.getNodes()], + this.path, + ) + ); + } + return false; + } + + public solve(constraint: Constraint): void { + + let queue = [ constraint ]; + + let errorCount = 0; + + while (queue.length > 0) { + + const constraint = queue.shift()!; + + switch (constraint.kind) { + + case ConstraintKind.Many: + { + for (const element of constraint.elements) { + queue.push(element); + } + break; + } + + case ConstraintKind.Equal: + { + this.constraint = constraint; + if (!this.unify(constraint.left, constraint.right, true)) { + errorCount++; + if (errorCount === this.maxTypeErrorCount) { + return; + } + } + break; + } + + } + + } } @@ -1905,18 +2232,3 @@ function getVariadicMember(node: StructPattern) {1713 } return null; } - -type HasTypeEnv - = ClassDeclaration - | InstanceDeclaration - | LetDeclaration - | ModuleDeclaration - | SourceFile - -function shouldChangeTypeEnvDuringVisit(node: Syntax): node is HasTypeEnv { - return node.kind === SyntaxKind.ClassDeclaration - || node.kind === SyntaxKind.InstanceDeclaration - || node.kind === SyntaxKind.ModuleDeclaration - || node.kind === SyntaxKind.SourceFile -} - diff --git a/compiler/src/solver.ts b/compiler/src/solver.ts deleted file mode 100644 index 94f4197a5..000000000 --- a/compiler/src/solver.ts +++ /dev/null @@ -1,312 +0,0 @@ -import { Constraint, ConstraintKind } from "./constraints"; -import { Diagnostics, FieldNotFoundDiagnostic, TypeclassNotFoundDiagnostic, TypeclassNotImplementedDiagnostic, TypeMismatchDiagnostic } from "./diagnostics"; -import { TAbsent, TField, TVar, TVSub, Type, TypeBase, TypeKind } from "./types"; -import { assert } from "./util"; - -export class ConstraintSolver { - - private path: string[] = []; - private constraint: Constraint | null = null; - private maxTypeErrorCount = 5; - - public solution = new TVSub; - - public constructor( - public diagnostics: Diagnostics, - private nextTypeVarId: number, - ) { - - } - - private find(type: Type): Type { - while (type.kind === TypeKind.Var && this.solution.has(type)) { - type = this.solution.get(type)!; - } - return type; - } - - private unifyField(left: Type, right: Type, enableDiagnostics: boolean): boolean { - - const swap = () => { [right, left] = [left, right]; } - - if (left.kind === TypeKind.Absent && right.kind === TypeKind.Absent) { - return true; - } - - if (right.kind === TypeKind.Absent) { - swap(); - } - - if (left.kind === TypeKind.Absent) { - assert(right.kind === TypeKind.Present); - const fieldName = this.path[this.path.length-1]; - if (enableDiagnostics) { - this.diagnostics.add( - new FieldNotFoundDiagnostic(fieldName, left.node, right.type.node, this.constraint!.firstNode) - ); - } - return false; - } - - assert(left.kind === TypeKind.Present && right.kind === TypeKind.Present); - return this.unify(left.type, right.type, enableDiagnostics); - } - - - private unify(left: Type, right: Type, enableDiagnostics: boolean): boolean { - - left = this.find(left); - right = this.find(right); - - // console.log(`unify ${describeType(left)} @ ${left.node && left.node.constructor && left.node.constructor.name} ~ ${describeType(right)} @ ${right.node && right.node.constructor && right.node.constructor.name}`); - - const swap = () => { [right, left] = [left, right]; } - - if (left.kind !== TypeKind.Var && right.kind === TypeKind.Var) { - swap(); - } - - if (left.kind === TypeKind.Var) { - - // Perform an occurs check, verifying whether left occurs - // somewhere inside the structure of right. If so, unification - // makes no sense. - if (right.hasTypeVar(left)) { - // TODO print a diagnostic - return false; - } - - // We are ready to join the types, so the first thing we do is - // propagating the type classes that 'left' requires to 'right'. - // If 'right' is another type variable, we're lucky. We just copy - // the missing type classes from 'left' to 'right'. Otherwise, - //const propagateClasses = (classes: Iterable, type: Type) => { - // if (type.kind === TypeKind.Var) { - // for (const constraint of classes) { - // type.context.add(constraint); - // } - // } else if (type.kind === TypeKind.Con) { - // for (const constraint of classes) { - // propagateClassTCon(constraint, type); - // } - // } else { - // //assert(false); - // //this.diagnostics.add(new ); - // } - //} - - //const propagateClassTCon = (clazz: ClassDeclaration, type: TCon) => { - // const s = this.findInstanceContext(type, clazz); - // let i = 0; - // for (const classes of s) { - // propagateClasses(classes, type.argTypes[i++]); - // } - //} - - //propagateClasses(left.context, right); - - // We are all clear; set the actual type of left to right. - this.solution.set(left, right); - - // These types will be join, and we'd like to track that - // into a special chain. - TypeBase.join(left, right); - - // if (left.node !== null) { - // right.node = left.node; - // } - - return true; - } - - if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { - let success = true; - if (!this.unify(left.paramType, right.paramType, enableDiagnostics)) { - success = false; - } - if (!this.unify(left.returnType, right.returnType, enableDiagnostics)) { - success = false; - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - - if (left.kind === TypeKind.Tuple && right.kind === TypeKind.Tuple) { - if (left.elementTypes.length === right.elementTypes.length) { - let success = false; - const count = left.elementTypes.length; - for (let i = 0; i < count; i++) { - if (!this.unify(left.elementTypes[i], right.elementTypes[i], enableDiagnostics)) { - success = false; - } - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - } - - if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) { - if (left.id === right.id) { - assert(left.argTypes.length === right.argTypes.length); - const count = left.argTypes.length; - let success = true; - for (let i = 0; i < count; i++) { - if (!this.unify(left.argTypes[i], right.argTypes[i], enableDiagnostics)) { - success = false; - } - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - } - - if (left.kind === TypeKind.Nil && right.kind === TypeKind.Nil) { - return true; - } - - if (left.kind === TypeKind.Field && right.kind === TypeKind.Field) { - if (left.name === right.name) { - let success = true; - this.path.push(left.name); - if (!this.unifyField(left.type, right.type, enableDiagnostics)) { - success = false; - } - this.path.pop(); - if (!this.unify(left.restType, right.restType, enableDiagnostics)) { - success = false; - } - return success; - } - let success = true; - const newRestType = new TVar(this.nextTypeVarId++); - if (!this.unify(left.restType, new TField(right.name, right.type, newRestType), enableDiagnostics)) { - success = false; - } - if (!this.unify(right.restType, new TField(left.name, left.type, newRestType), enableDiagnostics)) { - success = false; - } - return success; - } - - if (left.kind === TypeKind.Nil && right.kind === TypeKind.Field) { - swap(); - } - - if (left.kind === TypeKind.Field && right.kind === TypeKind.Nil) { - let success = true; - this.path.push(left.name); - if (!this.unifyField(left.type, new TAbsent(right.node), enableDiagnostics)) { - success = false; - } - this.path.pop(); - if (!this.unify(left.restType, right, enableDiagnostics)) { - success = false; - } - return success - } - - if (left.kind === TypeKind.Nominal && right.kind === TypeKind.Nominal) { - if (left.decl === right.decl) { - return true; - } - // fall through to error reporting - } - - if (left.kind === TypeKind.App && right.kind === TypeKind.App) { - return this.unify(left.left, right.left, enableDiagnostics) - && this.unify(left.right, right.right, enableDiagnostics); - } - - if (enableDiagnostics) { - this.diagnostics.add( - new TypeMismatchDiagnostic( - left.substitute(this.solution), - right.substitute(this.solution), - [...this.constraint!.getNodes()], - this.path, - ) - ); - } - return false; - } - - public solve(constraint: Constraint): void { - - let queue = [ constraint ]; - let next = []; - let isNext = false; - - let errorCount = 0; - - for (;;) { - - if (queue.length === 0) { - if (next.length === 0) { - break; - } - isNext = true; - queue = next; - next = []; - } - - const constraint = queue.shift()!; - - sw: switch (constraint.kind) { - - case ConstraintKind.Many: - { - for (const element of constraint.elements) { - queue.push(element); - } - break; - } - -// case ConstraintKind.Class: -// { -// if (constraint.type.kind === TypeKind.Var) { -// if (isNext) { -// // TODO -// } else { -// next.push(constraint); -// } -// } else { -// const classDecl = this.lookupClass(constraint.className); -// if (classDecl === null) { -// this.diagnostics.add(new TypeclassNotFoundDiagnostic(constraint.className, constraint.node)); -// break; -// } -// for (const instance of classDecl.getInstances()) { -// if (this.unify(instance.inferredType, constraint.type, false)) { -// break sw; -// } -// } -// this.diagnostics.add(new TypeclassNotImplementedDiagnostic(constraint.className, constraint.type, constraint.node)); -// } -// break; -// } - - case ConstraintKind.Equal: - { - this.constraint = constraint; - if (!this.unify(constraint.left, constraint.right, true)) { - errorCount++; - if (errorCount === this.maxTypeErrorCount) { - return; - } - } - break; - } - - } - - } - - } - -}