diff --git a/src/checker.ts b/src/checker.ts index e168a8f36..5705279a5 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -1,15 +1,14 @@ import { Expression, - LetDeclaration, Pattern, SourceFile, Syntax, SyntaxKind, TypeExpression } from "./cst"; -import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics"; +import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics"; import { assert } from "./util"; -import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl" +import { DirectedHashGraph, Graph, strongconnect } from "yagl" export enum TypeKind { Arrow, @@ -53,7 +52,9 @@ class TVar extends TypeBase { } public substitute(sub: TVSub): Type { - return sub.get(this) ?? this; + const other = sub.get(this); + return other === undefined + ? this : other.substitute(sub); } } @@ -269,6 +270,10 @@ class CEqual extends ConstraintBase { ); } + public dump(): void { + console.error(`${describeType(this.left)} ~ ${describeType(this.right)}`); + } + } class CMany extends ConstraintBase { @@ -323,7 +328,7 @@ export interface InferContext { typeVars: TVSet; env: TypeEnv; constraints: ConstraintSet; - returnType: Type; + returnType: Type | null; } export class Checker { @@ -331,17 +336,16 @@ export class Checker { private nextTypeVarId = 0; private nextConTypeId = 0; - private graph?: LabeledGraph; - private currentCycle?: Map; + //private graph?: Graph; + //private currentCycle?: Map; private stringType = new TCon(this.nextConTypeId++, [], 'String'); private intType = new TCon(this.nextConTypeId++, [], 'Int'); private boolType = new TCon(this.nextConTypeId++, [], 'Bool'); - private typeEnvs: TypeEnv[] = []; - private typeVars: TVSet[] = []; - private constraints: ConstraintSet[] = []; - private returnTypes: Type[] = []; + private contexts: InferContext[] = []; + + private solution = new TVSub(); public constructor( private diagnostics: Diagnostics @@ -363,7 +367,8 @@ export class Checker { private createTypeVar(): TVar { const typeVar = new TVar(this.nextTypeVarId++); - this.typeVars[this.typeVars.length-1].add(typeVar); + const context = this.contexts[this.contexts.length-1]; + context.typeVars.add(typeVar); return typeVar; } @@ -378,54 +383,33 @@ export class Checker { } case ConstraintKind.Equal: { - const count = this.constraints.length; - for (let i = count-1; i > 0; i--) { - const typeVars = this.typeVars[i]; - const constraints = this.constraints[i]; + const count = this.contexts.length; + let i; + for (i = count-1; i > 0; i--) { + const typeVars = this.contexts[i].typeVars; if (typeVars.intersectsType(constraint.left) || typeVars.intersectsType(constraint.right)) { - constraints.push(constraint); - return; + break; } } - this.constraints[0].push(constraint); - return; + this.contexts[i].constraints.push(constraint); + break; } } } private pushContext(context: InferContext) { - if (context.typeVars !== null) { - this.typeVars.push(context.typeVars); - } - if (context.env !== null) { - this.typeEnvs.push(context.env); - } - if (context.constraints !== null) { - this.constraints.push(context.constraints); - } - if (context.returnType !== null) { - this.returnTypes.push(context.returnType); - } + this.contexts.push(context); } private popContext(context: InferContext) { - if (context.typeVars !== null) { - this.typeVars.pop(); - } - if (context.env !== null) { - this.typeEnvs.pop(); - } - if (context.constraints !== null) { - this.constraints.pop(); - } - if (context.returnType !== null) { - this.returnTypes.pop(); - } + assert(this.contexts[this.contexts.length-1] === context); + this.contexts.pop(); } private lookup(name: string): Scheme | null { - for (let i = this.typeEnvs.length-1; i >= 0; i--) { - const scheme = this.typeEnvs[i].get(name); + for (let i = this.contexts.length-1; i >= 0; i--) { + const typeEnv = this.contexts[i].env; + const scheme = typeEnv.get(name); if (scheme !== undefined) { return scheme; } @@ -434,8 +418,9 @@ export class Checker { } private getReturnType(): Type { - assert(this.returnTypes.length > 0); - return this.returnTypes[this.returnTypes.length-1]; + const context = this.contexts[this.contexts.length-1]; + assert(context && context.returnType !== null); + return context.returnType; } private instantiate(scheme: Scheme): Type { @@ -445,13 +430,14 @@ export class Checker { } for (const constraint of scheme.constraints) { this.addConstraint(constraint.substitute(sub)); + // TODO keep record of a 'chain' of instantiations so that the diagnostics tool can output it on type error } return scheme.type.substitute(sub); } private addBinding(name: string, scheme: Scheme): void { - const env = this.typeEnvs[this.typeEnvs.length-1]; - env.set(name, scheme); + const context = this.contexts[this.contexts.length-1]; + context.env.set(name, scheme); } private forwardDeclare(node: Syntax): void { @@ -474,36 +460,7 @@ export class Checker { } case SyntaxKind.LetDeclaration: - { - const typeVars = new TVSet(); - const env = new TypeEnv(); - const constraints = new ConstraintSet(); - const returnType = this.createTypeVar(); - const context = { typeVars, env, constraints, returnType }; - node.context = context; - - this.pushContext(context); - - let type; - if (node.typeAssert !== null) { - type = this.inferTypeExpression(node.typeAssert.typeExpression); - } else { - type = this.createTypeVar(); - } - node.type = type; - - if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) { - for (const element of node.body.elements) { - this.forwardDeclare(element); - } - } - - this.popContext(context); - - this.inferBindings(node.pattern, type, context.typeVars, context.constraints); - break; - } } } @@ -564,56 +521,8 @@ export class Checker { } case SyntaxKind.LetDeclaration: - { - // Get the type that was stored on the node by forwardDeclare() - const type = node.type!; - const context = node.context!; - - this.pushContext(context); - - const paramTypes = []; - const returnType = context.returnType; - for (const param of node.params) { - const paramType = this.createTypeVar() - this.inferBindings(param.pattern, paramType, [], []); - paramTypes.push(paramType); - } - - if (node.body !== null) { - switch (node.body.kind) { - case SyntaxKind.ExprBody: - { - this.addConstraint( - new CEqual( - this.inferExpression(node.body.expression), - returnType, - node.body.expression - ) - ); - break; - } - case SyntaxKind.BlockBody: - { - for (const element of node.body.elements) { - this.infer(element); - } - break; - } - } - } - - this.addConstraint(new CEqual(type, new TArrow(paramTypes, returnType), node)); - - this.popContext(context); - - // FIXME these two may need to go below inferBindings - //this.typeVars.pop(); - //this.constraints.pop(); - break; - } - default: throw new Error(`Unexpected ${node}`); @@ -631,22 +540,17 @@ export class Checker { case SyntaxKind.ReferenceExpression: { assert(node.name.modulePath.length === 0); - const target = node.getScope().lookup(node.name.name.text) as LetDeclaration; - if (target === node.getScope().node) { - return target.type!; - } - const targetType = this.currentCycle.get(target); - if (targetType) { - return targetType; + const scope = node.getScope(); + const target = scope.lookup(node.name.name.text); + if (target !== null && target.type !== undefined) { + return target.type; } const scheme = this.lookup(node.name.name.text); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name)); return new TAny(); } - const type = this.instantiate(scheme); - this.currentCycle.set(target, type); - return type; + return this.instantiate(scheme); } case SyntaxKind.CallExpression: @@ -760,8 +664,8 @@ export class Checker { } - private computeReferenceGraph(node: SourceFile): LabeledGraph { - const graph = new LabeledDirectedHashGraph(); + private computeReferenceGraph(node: SourceFile): Graph { + const graph = new DirectedHashGraph(); const visit = (node: Syntax, source: Syntax | null) => { switch (node.kind) { case SyntaxKind.ConstantExpression: @@ -777,9 +681,15 @@ export class Checker { { // TODO only add references to nodes on the same level assert(node.name.modulePath.length === 0); - const target = node.getScope().lookup(node.name.name.text); + let target = node.getScope().lookup(node.name.name.text); + if (target !== null && target.kind === SyntaxKind.Param) { + target = target.parent!; + if (source !== null) { + graph.addEdge(target, source); + } + } if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) { - graph.addEdge(source, target, node); + graph.addEdge(source, target); } break; } @@ -864,19 +774,14 @@ export class Checker { public check(node: SourceFile): void { - this.graph = this.computeReferenceGraph(node); - const typeVars = new TVSet(); const constraints = new ConstraintSet(); const env = new TypeEnv(); + const context: InferContext = { typeVars, constraints, env, returnType: null }; - this.typeVars.push(typeVars); - this.constraints.push(constraints); - this.typeEnvs.push(env); + this.pushContext(context); const a = this.createTypeVar(); - const b = this.createTypeVar(); - const d = this.createTypeVar(); env.set('String', new Forall([], [], this.stringType)); env.set('Int', new Forall([], [], this.intType)); @@ -889,38 +794,113 @@ export class Checker { env.set('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType))); env.set('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType))); - //this.infer(node); - for (const node of this.graph.getVertices()) { - this.forwardDeclare(node); - } - for (const nodes of strongconnect(this.graph)) { - this.currentCycle = new Map(); + const graph = this.computeReferenceGraph(node); + + for (const nodes of strongconnect(graph)) { + + const typeVars = new TVSet(); + const constraints = new ConstraintSet(); + for (const node of nodes) { - for (const node of nodes) { - this.currentCycle.set(node, null); + + assert(node.kind === SyntaxKind.LetDeclaration); + + const env = new TypeEnv(); + const context: InferContext = { + typeVars, + constraints, + env, + returnType: null, + }; + node.context = context; + + this.contexts.push(context); + + const returnType = this.createTypeVar(); + context.returnType = returnType; + + const paramTypes = []; + for (const param of node.params) { + const paramType = this.createTypeVar() + this.inferBindings(param.pattern, paramType, [], []); + paramTypes.push(paramType); } - this.infer(node); + + let type = new TArrow(paramTypes, returnType); + if (node.typeAssert !== null) { + this.addConstraint( + new CEqual( + this.inferTypeExpression(node.typeAssert.typeExpression), + type, + node.typeAssert + ) + ); + } + node.type = type; + + this.contexts.pop(); + + this.inferBindings(node.pattern, type, typeVars, constraints); } + + for (const node of nodes) { + + assert(node.kind === SyntaxKind.LetDeclaration); + + const context = node.context!; + const returnType = context.returnType!; + this.contexts.push(context); + + if (node.body !== null) { + switch (node.body.kind) { + case SyntaxKind.ExprBody: + { + this.addConstraint( + new CEqual( + this.inferExpression(node.body.expression), + returnType, + node.body.expression + ) + ); + break; + } + case SyntaxKind.BlockBody: + { + for (const element of node.body.elements) { + this.infer(element); + } + break; + } + } + } + + this.contexts.pop(); + } + + for (const node of nodes) { + assert(node.kind === SyntaxKind.LetDeclaration); + delete node.type; + } + } - this.currentCycle = new Map(); + for (const element of node.elements) { if (element.kind !== SyntaxKind.LetDeclaration) { - //this.forwardDeclare(element); this.infer(element); } } - this.typeVars.pop(); - this.constraints.pop(); - this.typeEnvs.pop(); + //this.forwardDeclare(node); + //this.infer(node); - this.solve(new CMany(constraints)); + this.popContext(context); + + this.solve(new CMany(constraints), this.solution); } - private solve(constraint: Constraint): TVSub { + private solve(constraint: Constraint, solution: TVSub): void { const queue = [ constraint ]; - const solution = new TVSub(); while (queue.length > 0) { @@ -953,8 +933,6 @@ export class Checker { } - return solution; - } private unify(left: Type, right: Type, solution: TVSub): boolean { diff --git a/src/cst.ts b/src/cst.ts index c7c9b167c..36150ef89 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -172,6 +172,7 @@ export type Syntax | Param | Body | StructDeclarationField + | TypeAssert | Declaration | Statement | Expression diff --git a/src/diagnostics.ts b/src/diagnostics.ts index 0abbd4f45..d3985c10f 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -145,7 +145,7 @@ export class BindingNotFoudDiagnostic { } -function describeType(type: Type): string { +export function describeType(type: Type): string { switch (type.kind) { case TypeKind.Any: return 'Any';