From 85528ad8aff13e42d6963339fd1c5ca036f52857 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sun, 11 Sep 2022 15:23:22 +0200 Subject: [PATCH] Improve handling of polymorphic datatypes --- src/checker.ts | 492 +++++++++++++++++++++++-------------- src/test/type-inference.md | 15 ++ 2 files changed, 320 insertions(+), 187 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index b3cf5108f..0e96415c1 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -22,8 +22,6 @@ import { import { assert, isEmpty } from "./util"; import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl" -// FIXME Duplicate definitions are not checked - const MAX_TYPE_ERROR_COUNT = 5; type NodeWithBindings = SourceFile | LetDeclaration; @@ -39,6 +37,7 @@ export enum TypeKind { Labeled, Record, App, + Variant, } abstract class TypeBase { @@ -278,10 +277,9 @@ export class TRecord extends TypeBase { public readonly kind = TypeKind.Record; - public nextRecord: TRecord | null = null; - public constructor( public decl: Syntax, + public typeVars: TVar[], public fields: Map, public node: Syntax | null = null, ) { @@ -297,6 +295,7 @@ export class TRecord extends TypeBase { public shallowClone(): TRecord { return new TRecord( this.decl, + this.typeVars, this.fields, this.node ); @@ -304,6 +303,15 @@ export class TRecord extends TypeBase { public substitute(sub: TVSub): Type { let changed = false; + const newTypeVars = []; + for (const typeVar of this.typeVars) { + const newTypeVar = typeVar.substitute(sub); + assert(newTypeVar.kind === TypeKind.Var); + if (newTypeVar !== typeVar) { + changed = true; + } + newTypeVars.push(newTypeVar); + } const newFields = new Map(); for (const [key, type] of this.fields) { const newType = type.substitute(sub); @@ -312,7 +320,7 @@ export class TRecord extends TypeBase { } newFields.set(key, newType); } - return changed ? new TRecord(this.decl, newFields, this.node) : this; + return changed ? new TRecord(this.decl, newTypeVars, newFields, this.node) : this; } } @@ -329,7 +337,10 @@ export class TApp extends TypeBase { super(node); } - public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): TApp { + public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): Type { + if (argTypes.length === 0) { + return operatorType; + } let count = argTypes.length; let result = argTypes[count-1]; for (let i = count-2; i >= 0; i--) { @@ -379,6 +390,56 @@ export class TApp extends TypeBase { } +export class TVariant extends TypeBase { + + public readonly kind = TypeKind.Variant; + + public constructor( + public typeVars: TVar[], + public elementTypes: Type[], + public node: Syntax | null = null, + ) { + super(node); + } + + public *getTypeVars(): Iterable { + for (const elementType of this.elementTypes) { + yield* elementType.getTypeVars(); + } + } + + public shallowClone(): Type { + return new TVariant( + this.typeVars, + this.elementTypes, + this.node, + ); + } + + public substitute(sub: TVSub): Type { + let changed = false; + const newTypeVars = []; + for (const typeVar of this.typeVars) { + const newTypeVar = typeVar.substitute(sub); + assert(newTypeVar.kind === TypeKind.Var); + if (newTypeVar !== typeVar) { + changed = true; + } + newTypeVars.push(newTypeVar); + } + const newElementTypes = []; + for (const elementType of this.elementTypes) { + const newElementType = elementType.substitute(sub); + if (newElementType !== elementType) { + changed = true; + } + newElementTypes.push(newElementType); + } + return changed ? new TVariant(newTypeVars, newElementTypes, this.node) : this; + } + +} + export type Type = TCon | TArrow @@ -387,6 +448,16 @@ export type Type | TLabeled | TRecord | TApp + | TVariant + +type KindedType + = TRecord + | TVariant + +function isKindedType(type: Type): type is KindedType { + return type.kind === TypeKind.Variant + || type.kind === TypeKind.Record; +} class TVSet { @@ -910,7 +981,9 @@ export class Checker { const declType = this.instantiate(scheme, node, sub); const argTypes = []; for (const typeVar of decl.tvs) { - argTypes.push(sub.get(typeVar)!); + const newTypeVar = sub.get(typeVar)!; + assert(newTypeVar.kind === TypeKind.Var); + argTypes.push(newTypeVar); } const fields = new Map(); for (const member of node.members) { @@ -937,10 +1010,10 @@ export class Checker { throw new Error(`Unexpected ${member}`); } } - const type = TApp.build(new TRecord(decl, fields, node), argTypes, node); + const type = new TRecord(decl, argTypes, fields, node); this.addConstraint( new CEqual( - TApp.build(declType, argTypes, node), + declType, type, node, ) @@ -1377,7 +1450,7 @@ export class Checker { } } this.popContext(context); - const type = new TRecord(node, fields, node); + const type = new TRecord(node, argTypes, fields, node); const scheme = new Forall(typeVars, constraints, type); parentEnv.add(node.name.text, scheme); node.tvs = argTypes; @@ -1604,195 +1677,240 @@ export class Checker { case ConstraintKind.Equal: { // constraint.dump(); - if (!this.unify(constraint.left, constraint.right, solution, constraint)) { + const unify = (left: Type, right: Type): boolean => { + + const resolveType = (type: Type): Type => { + while (type.kind === TypeKind.Var && solution.has(type)) { + type = solution.get(type)!; + } + return type; + } + + const simplifyType = (type: Type): Type => { + + type = resolveType(type); + + if (type.kind === TypeKind.App) { + const stack = []; + let i = 0; + let operatorType: Type = type; + do { + operatorType = resolveType(operatorType.operatorType); + } while (operatorType.kind === TypeKind.App); + assert(isKindedType(operatorType)); + let curr: Type = resolveType(type); + for (;;) { + while (curr.kind === TypeKind.App) { + stack.push(curr); + curr = resolveType(curr.operatorType); + } + if (curr !== operatorType) { + assert(i < operatorType.typeVars!.length); + unify(operatorType.typeVars![i++], curr); + } + if (stack.length === 0) { + break; + } + const next = stack.pop()!; + curr = resolveType(next.argType); + } + return operatorType; + } + return type; + } + + left = simplifyType(left); + right = simplifyType(right); + + if (left.kind === TypeKind.Var) { + if (right.hasTypeVar(left)) { + // TODO occurs check diagnostic + return false; + } + solution.set(left, right); + TypeBase.join(left, right); + return true; + } + + if (right.kind === TypeKind.Var) { + return unify(right, left); + } + + if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { + if (left.paramTypes.length !== right.paramTypes.length) { + this.diagnostics.add(new ArityMismatchDiagnostic(left, right)); + return false; + } + let success = true; + const count = left.paramTypes.length; + for (let i = 0; i < count; i++) { + if (!unify(left.paramTypes[i], right.paramTypes[i])) { + success = false; + } + } + if (!unify(left.returnType, right.returnType)) { + success = false; + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + + if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) { + return unify(left.returnType, right); + } + + if (right.kind === TypeKind.Arrow) { + return unify(right, left); + } + + if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) { + if (left.id === right.id) { + assert(left.argTypes.length === right.argTypes.length); + const count = left.argTypes.length; + let success = true; + for (let i = 0; i < count; i++) { + if (!unify(left.argTypes[i], right.argTypes[i])) { + success = false; + } + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + } + + // if (left.kind === TypeKind.App && right.kind === TypeKind.App) { + // let leftElements = [...left.getSequence()]; + // let rightElements = [...right.getSequence()]; + // if (leftElements.length !== rightElements.length) { + // this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node)); + // return false; + // } + // const count = leftElements.length; + // let success = true; + // for (let i = 0; i < count; i++) { + // if (!unify(leftElements[i], rightElements[i])) { + // success = false; + // } + // } + // return success; + // } + + if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) { + let success = false; + // This works like an ordinary union-find algorithm where an additional + // property 'fields' is carried over from the child nodes to the + // ever-changing root node. + const root = left.find(); + right.parent = root; + if (root.fields === undefined) { + root.fields = new Map([ [ root.name, root.type ] ]); + } + if (right.fields === undefined) { + right.fields = new Map([ [ right.name, right.type ] ]); + } + for (const [fieldName, fieldType] of right.fields) { + if (root.fields.has(fieldName)) { + if (!unify(root.fields.get(fieldName)!, fieldType)) { + success = false; + } + } else { + root.fields.set(fieldName, fieldType); + } + } + delete right.fields; + if (success) { + TypeBase.join(left, right); + } + return success; + } + + if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) { + if (left.decl !== right.decl) { + this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()])); + return false; + } + let success = true; + const remaining = new Set(right.fields.keys()); + for (const [fieldName, fieldType] of left.fields) { + if (right.fields.has(fieldName)) { + if (!unify(fieldType, right.fields.get(fieldName)!)) { + success = false; + } + remaining.delete(fieldName); + } else { + this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node)); + success = false; + } + } + for (const fieldName of remaining) { + this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node)); + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + + // while (left.kind === TypeKind.App) { + // left = left.operatorType; + // } + // while (right.kind === TypeKind.App) { + // right = right.operatorType; + // } + + if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { + let success = true; + if (right.fields === undefined) { + right.fields = new Map([ [ right.name, right.type ] ]); + } + for (const [fieldName, fieldType] of right.fields) { + if (left.fields.has(fieldName)) { + if (!unify(fieldType, left.fields.get(fieldName)!)) { + success = false; + } + } else { + this.diagnostics.add(new FieldMissingDiagnostic(left, fieldName, constraint.node)); + } + } + if (success) { + TypeBase.join(left, right); + } + return success; + } + + if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Record) { + return unify(right, left); + } + + this.diagnostics.add( + new UnificationFailedDiagnostic( + left.substitute(solution), + right.substitute(solution), + [...constraint.getNodes()], + ) + ); + return false; + } + + if (!unify(constraint.left, constraint.right)) { errorCount++; if (errorCount === MAX_TYPE_ERROR_COUNT) { return; } } + break; } + } } } - private unify(left: Type, right: Type, solution: TVSub, constraint: CEqual): boolean { - - while (left.kind === TypeKind.Var && solution.has(left)) { - left = solution.get(left)!; - } - while (right.kind === TypeKind.Var && solution.has(right)) { - right = solution.get(right)!; - } - - if (left.kind === TypeKind.Var) { - if (right.hasTypeVar(left)) { - // TODO occurs check diagnostic - return false; - } - solution.set(left, right); - TypeBase.join(left, right); - return true; - } - - if (right.kind === TypeKind.Var) { - return this.unify(right, left, solution, constraint); - } - - if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { - if (left.paramTypes.length !== right.paramTypes.length) { - this.diagnostics.add(new ArityMismatchDiagnostic(left, right)); - return false; - } - let success = true; - const count = left.paramTypes.length; - for (let i = 0; i < count; i++) { - if (!this.unify(left.paramTypes[i], right.paramTypes[i], solution, constraint)) { - success = false; - } - } - if (!this.unify(left.returnType, right.returnType, solution, constraint)) { - success = false; - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - - if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) { - return this.unify(left.returnType, right, solution, constraint); - } - - if (right.kind === TypeKind.Arrow) { - return this.unify(right, left, solution, constraint); - } - - if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) { - if (left.id === right.id) { - assert(left.argTypes.length === right.argTypes.length); - const count = left.argTypes.length; - let success = true; - for (let i = 0; i < count; i++) { - if (!this.unify(left.argTypes[i], right.argTypes[i], solution, constraint)) { - success = false; - } - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - } - - if (left.kind === TypeKind.App && right.kind === TypeKind.App) { - let leftElements = [...left.getSequence()]; - let rightElements = [...right.getSequence()]; - if (leftElements.length !== rightElements.length) { - this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node)); - return false; - } - const count = leftElements.length; - let success = true; - for (let i = 0; i < count; i++) { - if (!this.unify(leftElements[i], rightElements[i], solution, constraint)) { - success = false; - } - } - return success; - } - - if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) { - let success = false; - // This works like an ordinary union-find algorithm where an additional - // property 'fields' is carried over from the child nodes to the - // ever-changing root node. - const root = left.find(); - right.parent = root; - if (root.fields === undefined) { - root.fields = new Map([ [ root.name, root.type ] ]); - } - if (right.fields === undefined) { - right.fields = new Map([ [ right.name, right.type ] ]); - } - for (const [fieldName, fieldType] of right.fields) { - if (root.fields.has(fieldName)) { - if (!this.unify(root.fields.get(fieldName)!, fieldType, solution, constraint)) { - success = false; - } - } else { - root.fields.set(fieldName, fieldType); - } - } - delete right.fields; - if (success) { - TypeBase.join(left, right); - } - return success; - } - - if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) { - if (left.decl !== right.decl) { - this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()])); - return false; - } - let success = true; - const remaining = new Set(right.fields.keys()); - for (const [fieldName, fieldType] of left.fields) { - if (right.fields.has(fieldName)) { - if (!this.unify(fieldType, right.fields.get(fieldName)!, solution, constraint)) { - success = false; - } - remaining.delete(fieldName); - } else { - this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node)); - success = false; - } - } - for (const fieldName of remaining) { - this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node)); - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - - if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { - let success = true; - if (right.fields === undefined) { - right.fields = new Map([ [ right.name, right.type ] ]); - } - for (const [fieldName, fieldType] of right.fields) { - if (left.fields.has(fieldName)) { - if (!this.unify(fieldType, left.fields.get(fieldName)!, solution, constraint)) { - success = false; - } - } else { - this.diagnostics.add(new FieldMissingDiagnostic(left, fieldName)); - } - } - if (success) { - TypeBase.join(left, right); - } - return success; - } - - if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Record) { - return this.unify(right, left, solution, constraint); - } - - this.diagnostics.add( - new UnificationFailedDiagnostic( - left.substitute(solution), - right.substitute(solution), - [...constraint.getNodes()], - ) - ); - return false; - } - } diff --git a/src/test/type-inference.md b/src/test/type-inference.md index 37b2e0217..19b7c4cf7 100644 --- a/src/test/type-inference.md +++ b/src/test/type-inference.md @@ -70,3 +70,18 @@ let is_odd x. not (is_even True) ``` + +### Polymorphic records can be partially typed + +``` +struct Timestamped a b. + first: a + second: b + timestamp: Int + +type Foo = Timestamped Int + +type Bar = Foo Int + +let t : Bar = Timestamped { first = "bar", second = 1, timestamp = 12345 } +```