Fix rigid type vars not instantiating correctly by introducing

union-find
This commit is contained in:
Sam Vervaeck 2023-06-28 22:09:17 +02:00
parent 8d2f3c4977
commit d194ff9b2e
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
2 changed files with 135 additions and 11 deletions

View file

@ -1,7 +1,4 @@
// FIXME Something wrong with eager solving unifying a3 ~ a and therefore removing polymorphism
// TODO Add simplifyType() in instantiate() to fix this
import { import {
ClassDeclaration, ClassDeclaration,
Expression, Expression,
@ -680,6 +677,78 @@ export class Checker {
return sub; return sub;
} }
private simplifyType(type: Type): Type {
type = type.find();
switch (type.kind) {
case TypeKind.UniVar:
case TypeKind.RigidVar:
case TypeKind.Nil:
case TypeKind.Absent:
case TypeKind.Nominal:
case TypeKind.Con:
return type;
case TypeKind.TupleIndex:
{
const tupleType = this.simplifyType(type.tupleType);
if (tupleType.kind === TypeKind.Tuple) {
// TODO check bounds and add diagnostic
const newType = tupleType.elementTypes[type.index];
type.set(newType);
return newType;
}
return type;
}
case TypeKind.App:
{
const left = type.left.find();
const right = type.right.find();
if (left === type.left && right === type.right) {
return type;
}
return new TApp(left, right, type.node);
}
case TypeKind.Arrow:
{
const paramType = type.paramType.find();
const returnType = type.returnType.find();
if (paramType === type.paramType && returnType === type.returnType) {
return type;
}
return new TArrow(paramType, returnType, type.node);
}
case TypeKind.Field:
{
const newType = type.type.find();
const newRestType = type.restType.find();
if (newType === type.type && newRestType === type.restType) {
return type;
}
return new TField(type.name, newType, newRestType, type.node);
}
case TypeKind.Present:
{
const newType = type.type.find();
if (newType === type.type) {
return type;
}
return new TPresent(newType, type.node);
}
case TypeKind.Tuple:
{
let changed = false;
const newElementTypes = [];
for (const elementType of type.elementTypes) {
const newElementType = elementType.find();
newElementTypes.push(newElementType);
if (newElementType !== elementType) {
changed = true;
}
}
return changed ? new TTuple(newElementTypes, type.node) : type;
}
}
}
private instantiate(scheme: Scheme, node: Syntax | null, sub = this.createSubstitution(scheme)): Type { private instantiate(scheme: Scheme, node: Syntax | null, sub = this.createSubstitution(scheme)): Type {
const transform = (constraint: Constraint): Constraint => { const transform = (constraint: Constraint): Constraint => {
switch (constraint.kind) { switch (constraint.kind) {
@ -691,8 +760,9 @@ export class Checker {
return new CMany(newConstraints); return new CMany(newConstraints);
case ConstraintKind.Empty: case ConstraintKind.Empty:
return constraint; return constraint;
case ConstraintKind.Class:
case ConstraintKind.Equal: case ConstraintKind.Equal:
constraint.left = this.simplifyType(constraint.left)
constraint.right = this.simplifyType(constraint.right)
const newConstraint = constraint.substitute(sub); const newConstraint = constraint.substitute(sub);
newConstraint.node = node; newConstraint.node = node;
newConstraint.prevInstantiation = constraint; newConstraint.prevInstantiation = constraint;
@ -702,7 +772,7 @@ export class Checker {
} }
} }
this.addConstraint(transform(scheme.constraint)); this.addConstraint(transform(scheme.constraint));
return scheme.type.substitute(sub); return this.simplifyType(scheme.type).substitute(sub);
} }
private addBinding(name: string, scheme: Scheme, kind: Symkind): void { private addBinding(name: string, scheme: Scheme, kind: Symkind): void {
@ -1967,10 +2037,10 @@ export class Checker {
private unify(left: Type, right: Type, enableDiagnostics: boolean): boolean { private unify(left: Type, right: Type, enableDiagnostics: boolean): boolean {
console.log(`unify ${describeType(left)} @ ${left.node && left.node.constructor && left.node.constructor.name} ~ ${describeType(right)} @ ${right.node && right.node.constructor && right.node.constructor.name}`); // console.log(`unify ${describeType(left)} @ ${left.node && left.node.constructor && left.node.constructor.name} ~ ${describeType(right)} @ ${right.node && right.node.constructor && right.node.constructor.name}`);
left = this.find(left); left = this.simplifyType(left);
right = this.find(right); right = this.simplifyType(right);
const swap = () => { [right, left] = [left, right]; } const swap = () => { [right, left] = [left, right]; }
@ -2024,7 +2094,7 @@ export class Checker {
//propagateClasses(left.context, right); //propagateClasses(left.context, right);
// We are all clear; set the actual type of left to right. // We are all clear; set the actual type of left to right.
this.solution.set(left, right); left.set(right);
// This is a very specific adjustment that is critical to the // This is a very specific adjustment that is critical to the
// well-functioning of the infer/unify algorithm. When addConstraint() is // well-functioning of the infer/unify algorithm. When addConstraint() is
@ -2159,8 +2229,8 @@ export class Checker {
if (enableDiagnostics) { if (enableDiagnostics) {
this.diagnostics.add( this.diagnostics.add(
new TypeMismatchDiagnostic( new TypeMismatchDiagnostic(
left.substitute(this.solution), this.simplifyType(left),
right.substitute(this.solution), this.simplifyType(right),
[...this.constraint!.getNodes()], [...this.constraint!.getNodes()],
this.path, this.path,
) )

View file

@ -8,6 +8,7 @@ export enum TypeKind {
RigidVar, RigidVar,
Con, Con,
Tuple, Tuple,
TupleIndex,
App, App,
Nominal, Nominal,
Field, Field,
@ -20,6 +21,8 @@ export abstract class TypeBase {
public abstract readonly kind: TypeKind; public abstract readonly kind: TypeKind;
public parent: Type = this as unknown as Type;
public next: Type = this as any; public next: Type = this as any;
public abstract node: Syntax | null; public abstract node: Syntax | null;
@ -36,6 +39,19 @@ export abstract class TypeBase {
public abstract substitute(sub: TVSub): Type; public abstract substitute(sub: TVSub): Type;
public find(): Type {
let curr = this as unknown as Type;
while (curr.parent !== curr) {
curr.parent = curr.parent.parent;
curr = curr.parent;
}
return curr;
}
public set(newType: Type): void {
this.find().parent = newType;
}
public hasTypeVar(tv: TUniVar): boolean { public hasTypeVar(tv: TUniVar): boolean {
for (const other of this.getTypeVars()) { for (const other of this.getTypeVars()) {
if (tv.id === other.id) { if (tv.id === other.id) {
@ -308,6 +324,43 @@ export class TCon extends TypeBase {
} }
export class TTupleIndex extends TypeBase {
public readonly kind = TypeKind.TupleIndex;
public constructor(
public tupleType: Type,
public index: number,
public node: Syntax | null = null,
) {
super();
}
public getTypeVars(): Iterable<TVar> {
return this.tupleType.getTypeVars();
}
public substitute(sub: TVSub): Type {
const newTupleType = this.tupleType.substitute(sub);
if (newTupleType === this.tupleType) {
return this;
}
return new TTupleIndex(newTupleType, this.index);
}
public shallowClone(): TTupleIndex {
return new TTupleIndex(
this.tupleType,
this.index,
);
}
public [toStringTag](_depth: number, options: InspectOptions, inspect: InspectFn): string {
return inspect(this.tupleType, options) + '.' + this.index;
}
}
export class TTuple extends TypeBase { export class TTuple extends TypeBase {
public readonly kind = TypeKind.Tuple; public readonly kind = TypeKind.Tuple;
@ -509,6 +562,7 @@ export type Type
| TNil | TNil
| TPresent | TPresent
| TAbsent | TAbsent
| TTupleIndex
export type TVar export type TVar
= TUniVar = TUniVar