import { Expression, LetDeclaration, Pattern, SourceFile, Syntax, SyntaxKind, TypeExpression } from "./cst";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics";
import { assert } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"

/**
 * Discriminator stored in the `kind` field of every {@link Type}.
 */
export enum TypeKind {
  Arrow,
  Var,
  Con,
  Any,
  Tuple,
}

/**
 * Operations shared by all type representations.
 */
abstract class TypeBase {

  public abstract readonly kind: TypeKind;

  /**
   * Yields every type variable that occurs in this type (duplicates are not
   * filtered out).
   */
  public abstract getTypeVars(): Iterable<TVar>;

  /**
   * Replaces the type variables that have an entry in `sub`.
   *
   * Implementations return `this` when nothing changed. Replacement is a
   * single pass: types taken from `sub` are not substituted again.
   */
  public abstract substitute(sub: TVSub): Type;

  /**
   * Returns `true` when a type variable with the same id as `tv` occurs in
   * this type.
   */
  public hasTypeVar(tv: TVar): boolean {
    for (const other of this.getTypeVars()) {
      if (tv.id === other.id) {
        return true;
      }
    }
    return false;
  }

}

/**
 * A type variable, identified by its numeric `id`.
 */
class TVar extends TypeBase {

  public readonly kind = TypeKind.Var;

  public constructor(
    public id: number,
  ) {
    super();
  }

  public *getTypeVars(): Iterable<TVar> {
    yield this;
  }

  public substitute(sub: TVSub): Type {
    return sub.get(this) ?? this;
  }

}

/**
 * A function type: a list of parameter types and a return type.
 */
export class TArrow extends TypeBase {

  public readonly kind = TypeKind.Arrow;

  public constructor(
    public paramTypes: Type[],
    public returnType: Type,
  ) {
    super();
  }

  public *getTypeVars(): Iterable<TVar> {
    for (const paramType of this.paramTypes) {
      yield* paramType.getTypeVars();
    }
    yield* this.returnType.getTypeVars();
  }

  public substitute(sub: TVSub): Type {
    let changed = false;
    const newParamTypes: Type[] = [];
    for (const paramType of this.paramTypes) {
      const newParamType = paramType.substitute(sub);
      if (newParamType !== paramType) {
        changed = true;
      }
      newParamTypes.push(newParamType);
    }
    const newReturnType = this.returnType.substitute(sub);
    if (newReturnType !== this.returnType) {
      changed = true;
    }
    // Keep the original object when no component was replaced.
    return changed ? new TArrow(newParamTypes, newReturnType) : this;
  }

}

/**
 * A type constructor applied to zero or more argument types. Two constructor
 * types are compared by `id`; `displayName` is carried along unchanged.
 */
class TCon extends TypeBase {

  public readonly kind = TypeKind.Con;

  public constructor(
    public id: number,
    public argTypes: Type[],
    public displayName: string,
  ) {
    super();
  }

  public *getTypeVars(): Iterable<TVar> {
    for (const argType of this.argTypes) {
      yield* argType.getTypeVars();
    }
  }

  public substitute(sub: TVSub): Type {
    let changed = false;
    const newArgTypes: Type[] = [];
    for (const argType of this.argTypes) {
      const newArgType = argType.substitute(sub);
      if (newArgType !== argType) {
        changed = true;
      }
      newArgTypes.push(newArgType);
    }
    return changed ? new TCon(this.id, newArgTypes, this.displayName) : this;
  }

}

/**
 * A type that unifies with every non-variable type. The checker returns it
 * after reporting a missing binding.
 */
class TAny extends TypeBase {

  public readonly kind = TypeKind.Any;

  public *getTypeVars(): Iterable<TVar> {

  }

  public substitute(sub: TVSub): Type {
    return this;
  }

}

/**
 * A tuple type; the empty tuple is used as the type of a bare `return`.
 */
class TTuple extends TypeBase {

  public readonly kind = TypeKind.Tuple;

  public constructor(
    public elementTypes: Type[],
  ) {
    super();
  }

  public *getTypeVars(): Iterable<TVar> {
    for (const elementType of this.elementTypes) {
      yield* elementType.getTypeVars();
    }
  }

  public substitute(sub: TVSub): Type {
    let changed = false;
    const newElementTypes: Type[] = [];
    for (const elementType of this.elementTypes) {
      const newElementType = elementType.substitute(sub);
      if (newElementType !== elementType) {
        changed = true;
      }
      newElementTypes.push(newElementType);
    }
    return changed ? new TTuple(newElementTypes) : this;
  }

}

export type Type
  = TCon
  | TArrow
  | TVar
  | TAny
  | TTuple

/**
 * A set of type variables keyed by variable id.
 */
class TVSet {

  private mapping = new Map<number, TVar>();

  public add(tv: TVar): void {
    this.mapping.set(tv.id, tv);
  }

  public has(tv: TVar): boolean {
    return this.mapping.has(tv.id);
  }

  /**
   * Returns `true` when at least one type variable of `type` is in this set.
   */
  public intersectsType(type: Type): boolean {
    for (const tv of type.getTypeVars()) {
      if (this.has(tv)) {
        return true;
      }
    }
    return false;
  }

  public delete(tv: TVar): void {
    this.mapping.delete(tv.id);
  }

  public [Symbol.iterator](): Iterator<TVar> {
    return this.mapping.values();
  }

}

/**
 * A substitution: a mapping from type variable id to the type bound to it.
 */
class TVSub {

  private mapping = new Map<number, Type>();

  public set(tv: TVar, type: Type): void {
    this.mapping.set(tv.id, type);
  }

  public get(tv: TVar): Type | undefined {
    return this.mapping.get(tv.id);
  }

  public has(tv: TVar): boolean {
    return this.mapping.has(tv.id);
  }

  public delete(tv: TVar): void {
    this.mapping.delete(tv.id);
  }

  public values(): Iterable<Type> {
    return this.mapping.values();
  }

}

const enum ConstraintKind {
  Equal,
  Many,
}

abstract class ConstraintBase {

  /**
   * Returns a new constraint with `sub` applied to every type it mentions.
   */
  public abstract substitute(sub: TVSub): Constraint;

}

/**
 * Requires `left` and `right` to unify. `node` is the syntax the constraint
 * originated from and is passed along when a diagnostic is reported.
 */
class CEqual extends ConstraintBase {

  public readonly kind = ConstraintKind.Equal;

  public constructor(
    public left: Type,
    public right: Type,
    public node: Syntax,
  ) {
    super();
  }

  public substitute(sub: TVSub): Constraint {
    return new CEqual(
      this.left.substitute(sub),
      this.right.substitute(sub),
      this.node,
    );
  }

}

/**
 * A conjunction of constraints.
 */
class CMany extends ConstraintBase {

  public readonly kind = ConstraintKind.Many;

  public constructor(
    public elements: Constraint[]
  ) {
    super();
  }

  public substitute(sub: TVSub): Constraint {
    const newElements: Constraint[] = [];
    for (const element of this.elements) {
      newElements.push(element.substitute(sub));
    }
    return new CMany(newElements);
  }

}

type Constraint
  = CEqual
  | CMany

class ConstraintSet extends Array<Constraint> {
}

abstract class SchemeBase {
}

/**
 * A type scheme: `type` quantified over `tvs`, together with constraints that
 * are re-added (after renaming) each time the scheme is instantiated.
 *
 * `tvs` and `constraints` only need to be iterable because the checker stores
 * a declaration's live `TVSet` and `ConstraintSet` in its scheme.
 */
class Forall extends SchemeBase {

  public constructor(
    public tvs: Iterable<TVar>,
    public constraints: Iterable<Constraint>,
    public type: Type,
  ) {
    super();
  }

}

type Scheme
  = Forall

/**
 * Maps a name to the scheme bound to it.
 */
class TypeEnv extends Map<string, Scheme> {
}

/**
 * The per-declaration state that {@link Checker} pushes while it processes a
 * `let` declaration.
 */
export interface InferContext {
  typeVars: TVSet;
  env: TypeEnv;
  constraints: ConstraintSet;
  returnType: Type;
}

/**
 * Generates equality constraints for a source file and solves them by
 * unification, reporting failures to the supplied diagnostics sink.
 */
export class Checker {

  private nextTypeVarId = 0;
  private nextConTypeId = 0;

  // NOTE(review): the yagl type arguments are assumed to be
  // <vertex, edge label> — confirm against yagl's declarations.
  private graph?: LabeledGraph<Syntax, Syntax>;

  // Cache consulted by reference expressions. `null` marks a declaration of
  // the component being inferred that has no instantiated type yet.
  // Initialised eagerly so inferExpression() never dereferences `undefined`.
  private currentCycle = new Map<Syntax, Type | null>();

  private stringType = new TCon(this.nextConTypeId++, [], 'String');
  private intType = new TCon(this.nextConTypeId++, [], 'Int');
  private boolType = new TCon(this.nextConTypeId++, [], 'Bool');

  // Parallel stacks; the last element of each is the innermost context.
  private typeEnvs: TypeEnv[] = [];
  private typeVars: TVSet[] = [];
  private constraints: ConstraintSet[] = [];
  private returnTypes: Type[] = [];

  public constructor(
    private diagnostics: Diagnostics
  ) {

  }

  public getIntType(): Type {
    return this.intType;
  }

  public getStringType(): Type {
    return this.stringType;
  }

  public getBoolType(): Type {
    return this.boolType;
  }

  /**
   * Creates a fresh type variable and registers it in the innermost
   * type-variable set.
   */
  private createTypeVar(): TVar {
    const typeVar = new TVar(this.nextTypeVarId++);
    this.typeVars[this.typeVars.length-1].add(typeVar);
    return typeVar;
  }

  /**
   * Records a constraint.
   *
   * A `CMany` is flattened. An equality is attached to the innermost context
   * (searching down to index 1) whose type-variable set intersects either
   * side; when none matches it goes to the context at index 0.
   */
  private addConstraint(constraint: Constraint): void {
    switch (constraint.kind) {
      case ConstraintKind.Many:
      {
        for (const element of constraint.elements) {
          this.addConstraint(element);
        }
        return;
      }
      case ConstraintKind.Equal:
      {
        const count = this.constraints.length;
        for (let i = count-1; i > 0; i--) {
          const typeVars = this.typeVars[i];
          const constraints = this.constraints[i];
          if (typeVars.intersectsType(constraint.left) || typeVars.intersectsType(constraint.right)) {
            constraints.push(constraint);
            return;
          }
        }
        this.constraints[0].push(constraint);
        return;
      }
    }
  }

  /**
   * Pushes every non-null component of `context` onto its stack.
   */
  private pushContext(context: InferContext) {
    if (context.typeVars !== null) {
      this.typeVars.push(context.typeVars);
    }
    if (context.env !== null) {
      this.typeEnvs.push(context.env);
    }
    if (context.constraints !== null) {
      this.constraints.push(context.constraints);
    }
    if (context.returnType !== null) {
      this.returnTypes.push(context.returnType);
    }
  }

  /**
   * Pops one entry from each stack for which `context` has a non-null
   * component; the counterpart of {@link pushContext}.
   */
  private popContext(context: InferContext) {
    if (context.typeVars !== null) {
      this.typeVars.pop();
    }
    if (context.env !== null) {
      this.typeEnvs.pop();
    }
    if (context.constraints !== null) {
      this.constraints.pop();
    }
    if (context.returnType !== null) {
      this.returnTypes.pop();
    }
  }

  /**
   * Searches the environment stack from innermost to outermost for `name`.
   *
   * @returns the first scheme found, or `null` when the name is unbound.
   */
  private lookup(name: string): Scheme | null {
    for (let i = this.typeEnvs.length-1; i >= 0; i--) {
      const scheme = this.typeEnvs[i].get(name);
      if (scheme !== undefined) {
        return scheme;
      }
    }
    return null;
  }

  private getReturnType(): Type {
    assert(this.returnTypes.length > 0);
    return this.returnTypes[this.returnTypes.length-1];
  }

  /**
   * Instantiates `scheme`: each quantified variable is replaced by a fresh
   * one, the scheme's constraints are re-added under that renaming, and the
   * renamed type is returned.
   */
  private instantiate(scheme: Scheme): Type {
    const sub = new TVSub();
    for (const tv of scheme.tvs) {
      sub.set(tv, this.createTypeVar());
    }
    for (const constraint of scheme.constraints) {
      this.addConstraint(constraint.substitute(sub));
    }
    return scheme.type.substitute(sub);
  }

  /**
   * Binds `name` to `scheme` in the innermost environment.
   */
  private addBinding(name: string, scheme: Scheme): void {
    const env = this.typeEnvs[this.typeEnvs.length-1];
    env.set(name, scheme);
  }

  /**
   * Prepares `let` declarations for inference: creates the declaration's
   * context and type, stores both on the node, recurses into a block body,
   * and finally binds the declaration's pattern in the enclosing environment.
   * Node kinds without a case below are ignored.
   */
  private forwardDeclare(node: Syntax): void {

    switch (node.kind) {

      case SyntaxKind.SourceFile:
      {
        for (const element of node.elements) {
          this.forwardDeclare(element);
        }
        break;
      }

      case SyntaxKind.ExpressionStatement:
      case SyntaxKind.ReturnStatement:
      {
        // TODO This should be updated if block-scoped expressions are allowed.
        break;
      }

      case SyntaxKind.LetDeclaration:
      {
        const typeVars = new TVSet();
        const env = new TypeEnv();
        const constraints = new ConstraintSet();
        // Created before the context is pushed, so this variable is
        // registered in the enclosing type-variable set.
        const returnType = this.createTypeVar();
        const context = { typeVars, env, constraints, returnType };
        node.context = context;

        this.pushContext(context);

        let type: Type;
        if (node.typeAssert !== null) {
          type = this.inferTypeExpression(node.typeAssert.typeExpression);
        } else {
          type = this.createTypeVar();
        }
        node.type = type;

        if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
          for (const element of node.body.elements) {
            this.forwardDeclare(element);
          }
        }

        this.popContext(context);

        // The scheme keeps references to the live sets, so constraints added
        // to this context later are visible when the scheme is instantiated.
        this.inferBindings(node.pattern, type, context.typeVars, context.constraints);

        break;
      }

    }

  }

  /**
   * Generates constraints for a statement or declaration.
   *
   * @throws Error when `node` has a kind that is not handled here.
   */
  public infer(node: Syntax): void {

    switch (node.kind) {

      case SyntaxKind.SourceFile:
      {
        for (const element of node.elements) {
          this.infer(element);
        }
        break;
      }

      case SyntaxKind.ExpressionStatement:
      {
        this.inferExpression(node.expression);
        break;
      }

      case SyntaxKind.IfStatement:
      {
        for (const cs of node.cases) {
          if (cs.test !== null) {
            this.addConstraint(
              new CEqual(
                this.inferExpression(cs.test),
                this.getBoolType(),
                cs.test
              )
            );
          }
          for (const element of cs.elements) {
            this.infer(element);
          }
        }
        break;
      }

      case SyntaxKind.ReturnStatement:
      {
        let type: Type;
        if (node.expression === null) {
          type = new TTuple([]);
        } else {
          type = this.inferExpression(node.expression);
        }
        this.addConstraint(
          new CEqual(
            this.getReturnType(),
            type,
            node
          )
        );
        break;
      }

      case SyntaxKind.LetDeclaration:
      {
        // Get the type that was stored on the node by forwardDeclare()
        const type = node.type!;
        const context = node.context!;

        this.pushContext(context);

        const paramTypes: Type[] = [];
        const returnType = context.returnType;
        for (const param of node.params) {
          const paramType = this.createTypeVar()
          this.inferBindings(param.pattern, paramType, [], []);
          paramTypes.push(paramType);
        }

        if (node.body !== null) {
          switch (node.body.kind) {
            case SyntaxKind.ExprBody:
            {
              this.addConstraint(
                new CEqual(
                  this.inferExpression(node.body.expression),
                  returnType,
                  node.body.expression
                )
              );
              break;
            }
            case SyntaxKind.BlockBody:
            {
              for (const element of node.body.elements) {
                this.infer(element);
              }
              break;
            }
          }
        }

        this.addConstraint(new CEqual(type, new TArrow(paramTypes, returnType), node));

        this.popContext(context);

        // FIXME these two may need to go below inferBindings
        //this.typeVars.pop();
        //this.constraints.pop();

        break;
      }

      default:
        throw new Error(`Unexpected ${node}`);

    }

  }

  /**
   * Generates constraints for an expression and returns its type.
   *
   * A name that cannot be found is reported to the diagnostics sink and typed
   * as `TAny`.
   *
   * @throws Error when the expression kind, or the token kind of a constant,
   *         is not handled here.
   */
  public inferExpression(node: Expression): Type {

    switch (node.kind) {

      case SyntaxKind.NestedExpression:
        return this.inferExpression(node.expression);

      case SyntaxKind.ReferenceExpression:
      {
        assert(node.name.modulePath.length === 0);
        const target = node.getScope().lookup(node.name.name.text) as LetDeclaration;
        // A reference to the declaration that owns the current scope uses
        // that declaration's type directly instead of instantiating it.
        if (target === node.getScope().node) {
          return target.type!;
        }
        // Reuse a type already produced for this target; a `null` entry
        // falls through to instantiation below.
        const targetType = this.currentCycle.get(target);
        if (targetType) {
          return targetType;
        }
        const scheme = this.lookup(node.name.name.text);
        if (scheme === null) {
          this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name));
          return new TAny();
        }
        const type = this.instantiate(scheme);
        this.currentCycle.set(target, type);
        return type;
      }

      case SyntaxKind.CallExpression:
      {
        const opType = this.inferExpression(node.func);
        const retType = this.createTypeVar();
        const paramTypes: Type[] = [];
        for (const arg of node.args) {
          paramTypes.push(this.inferExpression(arg));
        }
        this.addConstraint(
          new CEqual(
            opType,
            new TArrow(paramTypes, retType),
            node
          )
        );
        return retType;
      }

      case SyntaxKind.ConstantExpression:
      {
        let ty: Type;
        switch (node.token.kind) {
          case SyntaxKind.StringLiteral:
            ty = this.getStringType();
            break;
          case SyntaxKind.Integer:
            ty = this.getIntType();
            break;
          default:
            // Fail loudly instead of returning `undefined` as a type.
            throw new Error(`Unexpected constant token ${node.token.constructor.name}`);
        }
        return ty;
      }

      case SyntaxKind.NamedTupleExpression:
      {
        const scheme = this.lookup(node.name.text);
        if (scheme === null) {
          this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
          return new TAny();
        }
        const type = this.instantiate(scheme);
        assert(type.kind === TypeKind.Con);
        const argTypes: Type[] = [];
        for (const element of node.elements) {
          argTypes.push(this.inferExpression(element));
        }
        return new TCon(type.id, argTypes, type.displayName);
      }

      case SyntaxKind.InfixExpression:
      {
        const scheme = this.lookup(node.operator.text);
        if (scheme === null) {
          this.diagnostics.add(new BindingNotFoudDiagnostic(node.operator.text, node.operator));
          return new TAny();
        }
        const opType = this.instantiate(scheme);
        const retType = this.createTypeVar();
        const leftType = this.inferExpression(node.left);
        const rightType = this.inferExpression(node.right);
        this.addConstraint(
          new CEqual(
            new TArrow([ leftType, rightType ], retType),
            opType,
            node,
          ),
        );
        return retType;
      }

      default:
        throw new Error(`Unexpected ${node.constructor.name}`);

    }

  }

  /**
   * Resolves a type expression to a type by instantiating the scheme bound
   * to its name. An unknown name is reported and typed as `TAny`.
   *
   * @throws Error when the type expression kind is not handled here.
   */
  public inferTypeExpression(node: TypeExpression): Type {

    switch (node.kind) {

      case SyntaxKind.ReferenceTypeExpression:
      {
        const scheme = this.lookup(node.name.text);
        if (scheme === null) {
          this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
          return new TAny();
        }
        return this.instantiate(scheme);
      }

      default:
        throw new Error(`Unrecognised ${node}`);

    }

  }

  /**
   * Binds the name introduced by `pattern` to `Forall(tvs, constraints, type)`
   * in the innermost environment. Only bind patterns are handled; any other
   * pattern kind is ignored.
   *
   * @param tvs Type variables the scheme quantifies over; any iterable.
   * @param constraints Constraints to replay on instantiation; any iterable.
   */
  public inferBindings(pattern: Pattern, type: Type, tvs: Iterable<TVar>, constraints: Iterable<Constraint>): void {

    switch (pattern.kind) {

      case SyntaxKind.BindPattern:
      {
        this.addBinding(pattern.name.text, new Forall(tvs, constraints, type));
        break;
      }

    }

  }

  /**
   * Builds a graph with one vertex per `let` declaration and an edge, labeled
   * with the reference expression, from the declaration whose body contains a
   * reference to the `let` declaration that the reference resolves to.
   *
   * @throws Error when a node kind is not handled by the traversal.
   */
  private computeReferenceGraph(node: SourceFile): LabeledGraph<Syntax, Syntax> {
    const graph = new LabeledDirectedHashGraph<Syntax, Syntax>();
    // `source` is the innermost enclosing let declaration, or null at the
    // top level of the file.
    const visit = (node: Syntax, source: Syntax | null) => {
      switch (node.kind) {
        case SyntaxKind.ConstantExpression:
          break;
        case SyntaxKind.SourceFile:
        {
          for (const element of node.elements) {
            visit(element, source);
          }
          break;
        }
        case SyntaxKind.ReferenceExpression:
        {
          // TODO only add references to nodes on the same level
          assert(node.name.modulePath.length === 0);
          const target = node.getScope().lookup(node.name.name.text);
          if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) {
            graph.addEdge(source, target, node);
          }
          break;
        }
        case SyntaxKind.NamedTupleExpression:
        {
          for (const arg of node.elements) {
            visit(arg, source);
          }
          break;
        }
        case SyntaxKind.NestedExpression:
        {
          visit(node.expression, source);
          break;
        }
        case SyntaxKind.InfixExpression:
        {
          visit(node.left, source);
          visit(node.right, source);
          break;
        }
        case SyntaxKind.CallExpression:
        {
          visit(node.func, source);
          for (const arg of node.args) {
            visit(arg, source);
          }
          break;
        }
        case SyntaxKind.IfStatement:
        {
          for (const cs of node.cases) {
            if (cs.test !== null) {
              visit(cs.test, source);
            }
            for (const element of cs.elements) {
              visit(element, source);
            }
          }
          break;
        }
        case SyntaxKind.ExpressionStatement:
        {
          visit(node.expression, source);
          break;
        }
        case SyntaxKind.ReturnStatement:
        {
          if (node.expression !== null) {
            visit(node.expression, source);
          }
          break;
        }
        case SyntaxKind.LetDeclaration:
        {
          graph.addVertex(node);
          if (node.body !== null) {
            switch (node.body.kind) {
              case SyntaxKind.ExprBody:
              {
                visit(node.body.expression, node);
                break;
              }
              case SyntaxKind.BlockBody:
              {
                for (const element of node.body.elements) {
                  visit(element, node);
                }
                break;
              }
            }
          }
          break;
        }
        default:
          throw new Error(`Unexpected ${node.constructor.name}`);
      }
    }
    visit(node, null);
    return graph;
  }

  /**
   * Type-checks a source file.
   *
   * Sets up the root context with the built-in bindings, forward-declares
   * every `let` declaration, infers the declarations one strongly connected
   * component of the reference graph at a time, infers the remaining
   * top-level elements, and finally solves the root constraint set.
   */
  public check(node: SourceFile): void {

    const graph = this.computeReferenceGraph(node);
    this.graph = graph;

    const typeVars = new TVSet();
    const constraints = new ConstraintSet();
    const env = new TypeEnv();
    this.typeVars.push(typeVars);
    this.constraints.push(constraints);
    this.typeEnvs.push(env);

    // `b` and `d` are not referenced by any builtin below; they are still
    // allocated so the numbering of later type variables stays the same.
    const a = this.createTypeVar();
    const b = this.createTypeVar();
    const d = this.createTypeVar();

    env.set('String', new Forall([], [], this.stringType));
    env.set('Int', new Forall([], [], this.intType));
    env.set('True', new Forall([], [], this.boolType));
    env.set('False', new Forall([], [], this.boolType));
    env.set('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
    env.set('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
    env.set('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
    env.set('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
    env.set('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
    env.set('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));

    //this.infer(node);

    for (const vertex of graph.getVertices()) {
      this.forwardDeclare(vertex);
    }

    for (const nodes of strongconnect(graph)) {
      this.currentCycle = new Map();
      for (const current of nodes) {
        // Reset the entry of every member of the component before each
        // member is inferred.
        for (const member of nodes) {
          this.currentCycle.set(member, null);
        }
        this.infer(current);
      }
    }

    this.currentCycle = new Map();

    for (const element of node.elements) {
      if (element.kind !== SyntaxKind.LetDeclaration) {
        //this.forwardDeclare(element);
        this.infer(element);
      }
    }

    this.typeVars.pop();
    this.constraints.pop();
    this.typeEnvs.pop();

    this.solve(new CMany(constraints));

  }

  /**
   * Solves `constraint` by unifying every equality it contains, reporting a
   * `UnificationFailedDiagnostic` for each one that fails.
   *
   * @returns the substitution built up while unifying.
   */
  private solve(constraint: Constraint): TVSub {

    const queue = [ constraint ];
    const solution = new TVSub();

    while (queue.length > 0) {

      const constraint = queue.pop()!;

      switch (constraint.kind) {

        case ConstraintKind.Many:
        {
          for (const element of constraint.elements) {
            queue.push(element);
          }
          break;
        }

        case ConstraintKind.Equal:
        {
          if (!this.unify(constraint.left, constraint.right, solution)) {
            this.diagnostics.add(
              new UnificationFailedDiagnostic(
                constraint.left.substitute(solution),
                constraint.right.substitute(solution),
                constraint.node
              )
            );
          }
          break;
        }
      }

    }

    return solution;

  }

  /**
   * Follows variable bindings in `solution` until an unbound variable or a
   * non-variable type is reached. Only the outermost type is resolved.
   */
  private resolveTypeVar(type: Type, solution: TVSub): Type {
    let current = type;
    while (current.kind === TypeKind.Var) {
      const bound = solution.get(current);
      if (bound === undefined) {
        break;
      }
      current = bound;
    }
    return current;
  }

  /**
   * Unifies `left` with `right`, recording new variable bindings in
   * `solution`.
   *
   * @returns `true` on success. On failure `false` is returned; an arity
   *          mismatch between two arrows additionally adds an
   *          `ArityMismatchDiagnostic`.
   */
  private unify(left: Type, right: Type, solution: TVSub): boolean {

    // Resolve through the whole chain of bindings. Stopping after one step
    // would let `a -> b -> Int` be overwritten by a later `a ~ String`.
    left = this.resolveTypeVar(left, solution);
    right = this.resolveTypeVar(right, solution);

    if (left.kind === TypeKind.Var) {
      // A variable trivially unifies with itself; without this check the
      // occurs check below would reject it.
      if (right.kind === TypeKind.Var && right.id === left.id) {
        return true;
      }
      if (right.hasTypeVar(left)) {
        // TODO occurs check diagnostic
        return false;
      }
      solution.set(left, right);
      return true;
    }

    if (right.kind === TypeKind.Var) {
      return this.unify(right, left, solution);
    }

    if (left.kind === TypeKind.Any || right.kind === TypeKind.Any) {
      return true;
    }

    if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) {
      if (left.paramTypes.length !== right.paramTypes.length) {
        this.diagnostics.add(new ArityMismatchDiagnostic(left, right));
        return false;
      }
      // Keep going after a failure so the remaining components still
      // contribute bindings.
      let success = true;
      const count = left.paramTypes.length;
      for (let i = 0; i < count; i++) {
        if (!this.unify(left.paramTypes[i], right.paramTypes[i], solution)) {
          success = false;
        }
      }
      if (!this.unify(left.returnType, right.returnType, solution)) {
        success = false;
      }
      return success;
    }

    // An arrow without parameters is unified through its return type.
    if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) {
      return this.unify(left.returnType, right, solution);
    }

    if (right.kind === TypeKind.Arrow) {
      return this.unify(right, left, solution);
    }

    if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) {
      if (left.id !== right.id) {
        return false;
      }
      assert(left.argTypes.length === right.argTypes.length);
      const count = left.argTypes.length;
      for (let i = 0; i < count; i++) {
        if (!this.unify(left.argTypes[i], right.argTypes[i], solution)) {
          return false;
        }
      }
      return true;
    }

    // Tuples unify element-wise; this makes two bare `return` statements
    // (both typed as the empty tuple) agree with each other.
    if (left.kind === TypeKind.Tuple && right.kind === TypeKind.Tuple) {
      if (left.elementTypes.length !== right.elementTypes.length) {
        return false;
      }
      const count = left.elementTypes.length;
      for (let i = 0; i < count; i++) {
        if (!this.unify(left.elementTypes[i], right.elementTypes[i], solution)) {
          return false;
        }
      }
      return true;
    }

    return false;
  }

}