Improve handling of polymorphic datatypes

This commit is contained in:
Sam Vervaeck 2022-09-11 15:23:22 +02:00
parent 988215cdb3
commit 85528ad8af
2 changed files with 320 additions and 187 deletions

View file

@ -22,8 +22,6 @@ import {
import { assert, isEmpty } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
// FIXME Duplicate definitions are not checked
const MAX_TYPE_ERROR_COUNT = 5;
type NodeWithBindings = SourceFile | LetDeclaration;
@ -39,6 +37,7 @@ export enum TypeKind {
Labeled,
Record,
App,
Variant,
}
abstract class TypeBase {
@ -278,10 +277,9 @@ export class TRecord extends TypeBase {
public readonly kind = TypeKind.Record;
public nextRecord: TRecord | null = null;
public constructor(
public decl: Syntax,
public typeVars: TVar[],
public fields: Map<string, Type>,
public node: Syntax | null = null,
) {
@ -297,6 +295,7 @@ export class TRecord extends TypeBase {
public shallowClone(): TRecord {
return new TRecord(
this.decl,
this.typeVars,
this.fields,
this.node
);
@ -304,6 +303,15 @@ export class TRecord extends TypeBase {
public substitute(sub: TVSub): Type {
let changed = false;
const newTypeVars = [];
for (const typeVar of this.typeVars) {
const newTypeVar = typeVar.substitute(sub);
assert(newTypeVar.kind === TypeKind.Var);
if (newTypeVar !== typeVar) {
changed = true;
}
newTypeVars.push(newTypeVar);
}
const newFields = new Map();
for (const [key, type] of this.fields) {
const newType = type.substitute(sub);
@ -312,7 +320,7 @@ export class TRecord extends TypeBase {
}
newFields.set(key, newType);
}
return changed ? new TRecord(this.decl, newFields, this.node) : this;
return changed ? new TRecord(this.decl, newTypeVars, newFields, this.node) : this;
}
}
@ -329,7 +337,10 @@ export class TApp extends TypeBase {
super(node);
}
public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): TApp {
public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): Type {
if (argTypes.length === 0) {
return operatorType;
}
let count = argTypes.length;
let result = argTypes[count-1];
for (let i = count-2; i >= 0; i--) {
@ -379,6 +390,56 @@ export class TApp extends TypeBase {
}
export class TVariant extends TypeBase {
public readonly kind = TypeKind.Variant;
public constructor(
public typeVars: TVar[],
public elementTypes: Type[],
public node: Syntax | null = null,
) {
super(node);
}
public *getTypeVars(): Iterable<TVar> {
for (const elementType of this.elementTypes) {
yield* elementType.getTypeVars();
}
}
public shallowClone(): Type {
return new TVariant(
this.typeVars,
this.elementTypes,
this.node,
);
}
public substitute(sub: TVSub): Type {
let changed = false;
const newTypeVars = [];
for (const typeVar of this.typeVars) {
const newTypeVar = typeVar.substitute(sub);
assert(newTypeVar.kind === TypeKind.Var);
if (newTypeVar !== typeVar) {
changed = true;
}
newTypeVars.push(newTypeVar);
}
const newElementTypes = [];
for (const elementType of this.elementTypes) {
const newElementType = elementType.substitute(sub);
if (newElementType !== elementType) {
changed = true;
}
newElementTypes.push(newElementType);
}
return changed ? new TVariant(newTypeVars, newElementTypes, this.node) : this;
}
}
export type Type
= TCon
| TArrow
@ -387,6 +448,16 @@ export type Type
| TLabeled
| TRecord
| TApp
| TVariant
type KindedType
= TRecord
| TVariant
function isKindedType(type: Type): type is KindedType {
return type.kind === TypeKind.Variant
|| type.kind === TypeKind.Record;
}
class TVSet {
@ -910,7 +981,9 @@ export class Checker {
const declType = this.instantiate(scheme, node, sub);
const argTypes = [];
for (const typeVar of decl.tvs) {
argTypes.push(sub.get(typeVar)!);
const newTypeVar = sub.get(typeVar)!;
assert(newTypeVar.kind === TypeKind.Var);
argTypes.push(newTypeVar);
}
const fields = new Map();
for (const member of node.members) {
@ -937,10 +1010,10 @@ export class Checker {
throw new Error(`Unexpected ${member}`);
}
}
const type = TApp.build(new TRecord(decl, fields, node), argTypes, node);
const type = new TRecord(decl, argTypes, fields, node);
this.addConstraint(
new CEqual(
TApp.build(declType, argTypes, node),
declType,
type,
node,
)
@ -1377,7 +1450,7 @@ export class Checker {
}
}
this.popContext(context);
const type = new TRecord(node, fields, node);
const type = new TRecord(node, argTypes, fields, node);
const scheme = new Forall(typeVars, constraints, type);
parentEnv.add(node.name.text, scheme);
node.tvs = argTypes;
@ -1604,195 +1677,240 @@ export class Checker {
case ConstraintKind.Equal:
{
// constraint.dump();
if (!this.unify(constraint.left, constraint.right, solution, constraint)) {
const unify = (left: Type, right: Type): boolean => {
const resolveType = (type: Type): Type => {
while (type.kind === TypeKind.Var && solution.has(type)) {
type = solution.get(type)!;
}
return type;
}
const simplifyType = (type: Type): Type => {
type = resolveType(type);
if (type.kind === TypeKind.App) {
const stack = [];
let i = 0;
let operatorType: Type = type;
do {
operatorType = resolveType(operatorType.operatorType);
} while (operatorType.kind === TypeKind.App);
assert(isKindedType(operatorType));
let curr: Type = resolveType(type);
for (;;) {
while (curr.kind === TypeKind.App) {
stack.push(curr);
curr = resolveType(curr.operatorType);
}
if (curr !== operatorType) {
assert(i < operatorType.typeVars!.length);
unify(operatorType.typeVars![i++], curr);
}
if (stack.length === 0) {
break;
}
const next = stack.pop()!;
curr = resolveType(next.argType);
}
return operatorType;
}
return type;
}
left = simplifyType(left);
right = simplifyType(right);
if (left.kind === TypeKind.Var) {
if (right.hasTypeVar(left)) {
// TODO occurs check diagnostic
return false;
}
solution.set(left, right);
TypeBase.join(left, right);
return true;
}
if (right.kind === TypeKind.Var) {
return unify(right, left);
}
if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) {
if (left.paramTypes.length !== right.paramTypes.length) {
this.diagnostics.add(new ArityMismatchDiagnostic(left, right));
return false;
}
let success = true;
const count = left.paramTypes.length;
for (let i = 0; i < count; i++) {
if (!unify(left.paramTypes[i], right.paramTypes[i])) {
success = false;
}
}
if (!unify(left.returnType, right.returnType)) {
success = false;
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) {
return unify(left.returnType, right);
}
if (right.kind === TypeKind.Arrow) {
return unify(right, left);
}
if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) {
if (left.id === right.id) {
assert(left.argTypes.length === right.argTypes.length);
const count = left.argTypes.length;
let success = true;
for (let i = 0; i < count; i++) {
if (!unify(left.argTypes[i], right.argTypes[i])) {
success = false;
}
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
}
// if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
// let leftElements = [...left.getSequence()];
// let rightElements = [...right.getSequence()];
// if (leftElements.length !== rightElements.length) {
// this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node));
// return false;
// }
// const count = leftElements.length;
// let success = true;
// for (let i = 0; i < count; i++) {
// if (!unify(leftElements[i], rightElements[i])) {
// success = false;
// }
// }
// return success;
// }
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
let success = false;
// This works like an ordinary union-find algorithm where an additional
// property 'fields' is carried over from the child nodes to the
// ever-changing root node.
const root = left.find();
right.parent = root;
if (root.fields === undefined) {
root.fields = new Map([ [ root.name, root.type ] ]);
}
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of right.fields) {
if (root.fields.has(fieldName)) {
if (!unify(root.fields.get(fieldName)!, fieldType)) {
success = false;
}
} else {
root.fields.set(fieldName, fieldType);
}
}
delete right.fields;
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
if (left.decl !== right.decl) {
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
return false;
}
let success = true;
const remaining = new Set(right.fields.keys());
for (const [fieldName, fieldType] of left.fields) {
if (right.fields.has(fieldName)) {
if (!unify(fieldType, right.fields.get(fieldName)!)) {
success = false;
}
remaining.delete(fieldName);
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node));
success = false;
}
}
for (const fieldName of remaining) {
this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node));
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
// while (left.kind === TypeKind.App) {
// left = left.operatorType;
// }
// while (right.kind === TypeKind.App) {
// right = right.operatorType;
// }
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of right.fields) {
if (left.fields.has(fieldName)) {
if (!unify(fieldType, left.fields.get(fieldName)!)) {
success = false;
}
} else {
this.diagnostics.add(new FieldMissingDiagnostic(left, fieldName, constraint.node));
}
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Record) {
return unify(right, left);
}
this.diagnostics.add(
new UnificationFailedDiagnostic(
left.substitute(solution),
right.substitute(solution),
[...constraint.getNodes()],
)
);
return false;
}
if (!unify(constraint.left, constraint.right)) {
errorCount++;
if (errorCount === MAX_TYPE_ERROR_COUNT) {
return;
}
}
break;
}
}
}
}
private unify(left: Type, right: Type, solution: TVSub, constraint: CEqual): boolean {
while (left.kind === TypeKind.Var && solution.has(left)) {
left = solution.get(left)!;
}
while (right.kind === TypeKind.Var && solution.has(right)) {
right = solution.get(right)!;
}
if (left.kind === TypeKind.Var) {
if (right.hasTypeVar(left)) {
// TODO occurs check diagnostic
return false;
}
solution.set(left, right);
TypeBase.join(left, right);
return true;
}
if (right.kind === TypeKind.Var) {
return this.unify(right, left, solution, constraint);
}
if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) {
if (left.paramTypes.length !== right.paramTypes.length) {
this.diagnostics.add(new ArityMismatchDiagnostic(left, right));
return false;
}
let success = true;
const count = left.paramTypes.length;
for (let i = 0; i < count; i++) {
if (!this.unify(left.paramTypes[i], right.paramTypes[i], solution, constraint)) {
success = false;
}
}
if (!this.unify(left.returnType, right.returnType, solution, constraint)) {
success = false;
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) {
return this.unify(left.returnType, right, solution, constraint);
}
if (right.kind === TypeKind.Arrow) {
return this.unify(right, left, solution, constraint);
}
if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) {
if (left.id === right.id) {
assert(left.argTypes.length === right.argTypes.length);
const count = left.argTypes.length;
let success = true;
for (let i = 0; i < count; i++) {
if (!this.unify(left.argTypes[i], right.argTypes[i], solution, constraint)) {
success = false;
}
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
}
if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
let leftElements = [...left.getSequence()];
let rightElements = [...right.getSequence()];
if (leftElements.length !== rightElements.length) {
this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node));
return false;
}
const count = leftElements.length;
let success = true;
for (let i = 0; i < count; i++) {
if (!this.unify(leftElements[i], rightElements[i], solution, constraint)) {
success = false;
}
}
return success;
}
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
let success = false;
// This works like an ordinary union-find algorithm where an additional
// property 'fields' is carried over from the child nodes to the
// ever-changing root node.
const root = left.find();
right.parent = root;
if (root.fields === undefined) {
root.fields = new Map([ [ root.name, root.type ] ]);
}
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of right.fields) {
if (root.fields.has(fieldName)) {
if (!this.unify(root.fields.get(fieldName)!, fieldType, solution, constraint)) {
success = false;
}
} else {
root.fields.set(fieldName, fieldType);
}
}
delete right.fields;
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
if (left.decl !== right.decl) {
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
return false;
}
let success = true;
const remaining = new Set(right.fields.keys());
for (const [fieldName, fieldType] of left.fields) {
if (right.fields.has(fieldName)) {
if (!this.unify(fieldType, right.fields.get(fieldName)!, solution, constraint)) {
success = false;
}
remaining.delete(fieldName);
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node));
success = false;
}
}
for (const fieldName of remaining) {
this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node));
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of right.fields) {
if (left.fields.has(fieldName)) {
if (!this.unify(fieldType, left.fields.get(fieldName)!, solution, constraint)) {
success = false;
}
} else {
this.diagnostics.add(new FieldMissingDiagnostic(left, fieldName));
}
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Record) {
return this.unify(right, left, solution, constraint);
}
this.diagnostics.add(
new UnificationFailedDiagnostic(
left.substitute(solution),
right.substitute(solution),
[...constraint.getNodes()],
)
);
return false;
}
}

View file

@ -70,3 +70,18 @@ let is_odd x.
not (is_even True)
```
### Polymorphic records can be partially typed
```
struct Timestamped a b.
first: a
second: b
timestamp: Int
type Foo = Timestamped Int
type Bar = Foo Int
let t : Bar = Timestamped { first = "bar", second = 1, timestamp = 12345 }
```