Clean up type checking code a bit

This commit is contained in:
Sam Vervaeck 2023-08-02 10:37:13 +02:00
parent 89feeaadb6
commit 859b1676fd
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
4 changed files with 50 additions and 251 deletions

View file

@ -26,13 +26,12 @@ import {
TypeclassDeclaredTwiceDiagnostic, TypeclassDeclaredTwiceDiagnostic,
FieldNotFoundDiagnostic, FieldNotFoundDiagnostic,
TypeMismatchDiagnostic, TypeMismatchDiagnostic,
TupleIndexOutOfRangeDiagnostic,
} from "./diagnostics"; } from "./diagnostics";
import { assert, assertNever, isEmpty, MultiMap, toStringTag, InspectFn, implementationLimitation } from "./util"; import { assert, assertNever, isEmpty, MultiMap, toStringTag, InspectFn } from "./util";
import { Analyser } from "./analysis"; import { Analyser } from "./analysis";
import { InspectOptions } from "util"; import { InspectOptions } from "util";
import { TypeKind, TApp, TArrow, TCon, TField, TNil, TPresent, TTuple, TUniVar, TVSet, TVSub, Type, TypeBase, TAbsent, TRigidVar, TVar, TTupleIndex } from "./types"; import { TypeKind, TApp, TArrow, TCon, TField, TNil, TPresent, TUniVar, TVSet, TVSub, Type, TypeBase, TAbsent, TRigidVar, TVar, buildTupleTypeWithLoc, buildTupleType } from "./types";
import { CClass, CEmpty, CEqual, CMany, Constraint, ConstraintKind, ConstraintSet } from "./constraints"; import { CEmpty, CEqual, CMany, Constraint, ConstraintKind, ConstraintSet } from "./constraints";
// export class Qual { // export class Qual {
@ -78,10 +77,9 @@ import { CClass, CEmpty, CEqual, CMany, Constraint, ConstraintKind, ConstraintSe
// type Pred = IsInPred; // type Pred = IsInPred;
export const enum KindType { export const enum KindType {
Star, Type,
Arrow, Arrow,
Var, Var,
Row,
} }
class KVSub { class KVSub {
@ -138,17 +136,7 @@ class KVar extends KindBase {
class KType extends KindBase { class KType extends KindBase {
public readonly type = KindType.Star; public readonly type = KindType.Type;
public substitute(_sub: KVSub): Kind {
return this;
}
}
class KRow extends KindBase {
public readonly type = KindType.Row;
public substitute(_sub: KVSub): Kind { public substitute(_sub: KVSub): Kind {
return this; return this;
@ -186,8 +174,6 @@ export type Kind
= KType = KType
| KArrow | KArrow
| KVar | KVar
| KRow
abstract class SchemeBase { abstract class SchemeBase {
} }
@ -387,6 +373,7 @@ export class Checker {
private stringType = this.createTCon('String'); private stringType = this.createTCon('String');
private intType = this.createTCon('Int'); private intType = this.createTCon('Int');
private boolType = this.createTCon('Bool'); private boolType = this.createTCon('Bool');
private unitType = buildTupleType([]);
private contexts: InferContext[] = []; private contexts: InferContext[] = [];
@ -687,20 +674,6 @@ export class Checker {
case TypeKind.Absent: case TypeKind.Absent:
case TypeKind.Con: case TypeKind.Con:
return type; return type;
case TypeKind.TupleIndex:
{
const tupleType = this.simplifyType(type.tupleType);
if (tupleType.kind === TypeKind.Tuple) {
if (type.index >= tupleType.elementTypes.length) {
this.diagnostics.add(new TupleIndexOutOfRangeDiagnostic(type.index, tupleType));
return type;
}
const newType = tupleType.elementTypes[type.index];
type.set(newType);
return newType;
}
return type;
}
case TypeKind.App: case TypeKind.App:
{ {
const left = type.left.find(); const left = type.left.find();
@ -736,19 +709,6 @@ export class Checker {
} }
return new TPresent(newType, type.node); return new TPresent(newType, type.node);
} }
case TypeKind.Tuple:
{
let changed = false;
const newElementTypes = [];
for (const elementType of type.elementTypes) {
const newElementType = elementType.find();
newElementTypes.push(newElementType);
if (newElementType !== elementType) {
changed = true;
}
}
return changed ? new TTuple(newElementTypes, type.node) : type;
}
} }
} }
@ -1114,7 +1074,7 @@ export class Checker {
return this.unifyKind(b, a, node); return this.unifyKind(b, a, node);
} }
if (a.type === KindType.Star && b.type === KindType.Star) { if (a.type === KindType.Type && b.type === KindType.Type) {
return true; return true;
} }
@ -1185,7 +1145,7 @@ export class Checker {
{ {
let type; let type;
if (node.expression === null) { if (node.expression === null) {
type = new TTuple([]); type = this.unitType;
} else { } else {
type = this.inferExpression(node.expression); type = this.inferExpression(node.expression);
} }
@ -1356,7 +1316,7 @@ export class Checker {
} }
case SyntaxKind.TupleExpression: case SyntaxKind.TupleExpression:
type = new TTuple(node.elements.map(el => this.inferExpression(el)), node); type = buildTupleTypeWithLoc(node.elements.map(el => [el, this.inferExpression(el)]), node);
break; break;
case SyntaxKind.ReferenceExpression: case SyntaxKind.ReferenceExpression:
@ -1386,27 +1346,27 @@ export class Checker {
{ {
type = this.inferExpression(node.expression); type = this.inferExpression(node.expression);
for (const [_dot, name] of node.path) { for (const [_dot, name] of node.path) {
let label;
switch (name.kind) { switch (name.kind) {
case SyntaxKind.Identifier: case SyntaxKind.Identifier:
{ label = name.text;
const newFieldType = this.createTypeVar(name);
const newRestType = this.createTypeVar();
this.addConstraint(
new CEqual(
type,
new TField(name.text, new TPresent(newFieldType), newRestType, name),
node,
)
);
type = newFieldType;
break; break;
}
case SyntaxKind.Integer: case SyntaxKind.Integer:
type = new TTupleIndex(type, Number(name.value)); label = Number(name.value);
break; break;
default: default:
assertNever(name); assertNever(name);
} }
const newFieldType = this.createTypeVar(name);
const newRestType = this.createTypeVar();
this.addConstraint(
new CEqual(
type,
new TField(label, new TPresent(newFieldType), newRestType, name),
node,
)
);
type = newFieldType;
} }
break; break;
} }
@ -1542,7 +1502,7 @@ export class Checker {
case SyntaxKind.TupleTypeExpression: case SyntaxKind.TupleTypeExpression:
{ {
type = new TTuple(node.elements.map(el => this.inferTypeExpression(el, introduceTypeVars)), node); type = buildTupleTypeWithLoc(node.elements.map(el => [el, this.inferTypeExpression(el, introduceTypeVars)]), node);
break; break;
} }
@ -1579,10 +1539,11 @@ export class Checker {
case SyntaxKind.TypeExpressionWithConstraints: case SyntaxKind.TypeExpressionWithConstraints:
{ {
for (const constraint of node.constraints) { // TODO
implementationLimitation(constraint.types.length === 1); // for (const constraint of node.constraints) {
this.addConstraint(new CClass(constraint.name.text, this.inferTypeExpression(constraint.types[0]), constraint.name)); // implementationLimitation(constraint.types.length === 1);
} // this.addConstraint(new CClass(constraint.name.text, this.inferTypeExpression(constraint.types[0]), constraint.name));
// }
return this.inferTypeExpression(node.typeExpr, introduceTypeVars); return this.inferTypeExpression(node.typeExpr, introduceTypeVars);
} }
@ -1842,9 +1803,9 @@ export class Checker {
switch (member.kind) { switch (member.kind) {
case SyntaxKind.EnumDeclarationTupleElement: case SyntaxKind.EnumDeclarationTupleElement:
{ {
const argTypes = member.elements.map(el => this.inferTypeExpression(el, false)); const args: Array<[Syntax, Type]> = member.elements.map(el => [el, this.inferTypeExpression(el, false)]);
elementType = new TTuple(argTypes, member); elementType = buildTupleTypeWithLoc(args, member);
ctorType = TArrow.build(argTypes, appliedType, member); ctorType = TArrow.build(args.map(a => a[1]), appliedType, member);
break; break;
} }
case SyntaxKind.EnumDeclarationStructElement: case SyntaxKind.EnumDeclarationStructElement:
@ -2039,7 +2000,7 @@ export class Checker {
} }
private path: string[] = []; private path: (string | number)[] = [];
private constraint: Constraint | null = null; private constraint: Constraint | null = null;
private maxTypeErrorCount = 5; private maxTypeErrorCount = 5;
@ -2179,22 +2140,6 @@ export class Checker {
return success; return success;
} }
if (left.kind === TypeKind.Tuple && right.kind === TypeKind.Tuple) {
if (left.elementTypes.length === right.elementTypes.length) {
let success = false;
const count = left.elementTypes.length;
for (let i = 0; i < count; i++) {
if (!this.unify(left.elementTypes[i], right.elementTypes[i], enableDiagnostics)) {
success = false;
}
}
if (success) {
TypeBase.join(left, right);
}
return success;
}
}
if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) { if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) {
if (left.id === right.id) { if (left.id === right.id) {
TypeBase.join(left, right); TypeBase.join(left, right);

View file

@ -8,7 +8,6 @@ export const enum ConstraintKind {
Equal, Equal,
Many, Many,
Empty, Empty,
Class,
} }
abstract class ConstraintBase { abstract class ConstraintBase {
@ -113,32 +112,6 @@ export class CMany extends ConstraintBase {
} }
export class CClass extends ConstraintBase {
public readonly kind = ConstraintKind.Class;
public constructor(
public className: string,
public type: Type,
public node: Syntax | null = null,
) {
super();
}
public substitute(sub: TVSub): CClass {
return new CClass(this.className, this.type.substitute(sub));
}
public freeTypeVars(): Iterable<TVar> {
return this.type.getTypeVars();
}
public [toStringTag](_depth: number, options: InspectOptions, inspect: InspectFn) {
return this.className + ' ' + inspect(this.type, options);
}
}
export class CEmpty extends ConstraintBase { export class CEmpty extends ConstraintBase {
public readonly kind = ConstraintKind.Empty; public readonly kind = ConstraintKind.Empty;
@ -161,7 +134,6 @@ export type Constraint
= CEqual = CEqual
| CMany | CMany
| CEmpty | CEmpty
| CClass
export class ConstraintSet extends Array<Constraint> { export class ConstraintSet extends Array<Constraint> {

View file

@ -1,6 +1,6 @@
import { Kind, KindType } from "./checker"; import { Kind, KindType } from "./checker";
import { type Type, TypeKind, TTuple } from "./types" import { type Type, TypeKind } from "./types"
import { ClassConstraint, ClassDeclaration, IdentifierAlt, InstanceDeclaration, Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst"; import { ClassConstraint, ClassDeclaration, IdentifierAlt, InstanceDeclaration, Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst";
import { assertNever, countDigits, IndentWriter } from "./util"; import { assertNever, countDigits, IndentWriter } from "./util";
@ -186,7 +186,7 @@ export class TypeMismatchDiagnostic extends DiagnosticBase {
public left: Type, public left: Type,
public right: Type, public right: Type,
public trace: Syntax[], public trace: Syntax[],
public fieldPath: string[], public fieldPath: (string | number)[],
) { ) {
super(); super();
} }
@ -197,25 +197,6 @@ export class TypeMismatchDiagnostic extends DiagnosticBase {
} }
export class TupleIndexOutOfRangeDiagnostic extends DiagnosticBase {
public readonly kind = DiagnosticKind.TupleIndexOutOfRange;
public level = Level.Error;
public constructor(
public index: number,
public tupleType: TTuple,
) {
super();
}
public get position(): TextPosition | undefined {
return undefined;
}
}
export class FieldNotFoundDiagnostic extends DiagnosticBase { export class FieldNotFoundDiagnostic extends DiagnosticBase {
public readonly kind = DiagnosticKind.FieldNotFound; public readonly kind = DiagnosticKind.FieldNotFound;
@ -223,7 +204,7 @@ export class FieldNotFoundDiagnostic extends DiagnosticBase {
public level = Level.Error; public level = Level.Error;
public constructor( public constructor(
public fieldName: string, public fieldName: string | number,
public missing: Syntax | null, public missing: Syntax | null,
public present: Syntax | null, public present: Syntax | null,
public cause: Syntax | null = null, public cause: Syntax | null = null,
@ -283,7 +264,6 @@ export type Diagnostic
| TypeclassNotImplementedDiagnostic | TypeclassNotImplementedDiagnostic
| BindingNotFoundDiagnostic | BindingNotFoundDiagnostic
| TypeMismatchDiagnostic | TypeMismatchDiagnostic
| TupleIndexOutOfRangeDiagnostic
| UnexpectedTokenDiagnostic | UnexpectedTokenDiagnostic
| FieldNotFoundDiagnostic | FieldNotFoundDiagnostic
| KindMismatchDiagnostic | KindMismatchDiagnostic
@ -494,9 +474,6 @@ const DESCRIPTIONS: Partial<Record<SyntaxKind, string>> = {
[SyntaxKind.MatchKeyword]: "'match'", [SyntaxKind.MatchKeyword]: "'match'",
[SyntaxKind.TypeKeyword]: "'type'", [SyntaxKind.TypeKeyword]: "'type'",
[SyntaxKind.IdentifierAlt]: 'an identifier starting with an uppercase letter', [SyntaxKind.IdentifierAlt]: 'an identifier starting with an uppercase letter',
[SyntaxKind.ConstantExpression]: 'a constant expression',
[SyntaxKind.ReferenceExpression]: 'a reference expression',
[SyntaxKind.LineFoldEnd]: 'the end of the current line-fold',
[SyntaxKind.TupleExpression]: 'a tuple expression such as (1, 2)', [SyntaxKind.TupleExpression]: 'a tuple expression such as (1, 2)',
[SyntaxKind.ReferenceExpression]: 'a reference to some variable', [SyntaxKind.ReferenceExpression]: 'a reference to some variable',
[SyntaxKind.NestedExpression]: 'an expression nested with parentheses', [SyntaxKind.NestedExpression]: 'an expression nested with parentheses',
@ -558,19 +535,6 @@ export function describeType(type: Type): string {
{ {
return describeType(type.paramType) + ' -> ' + describeType(type.returnType); return describeType(type.paramType) + ' -> ' + describeType(type.returnType);
} }
case TypeKind.Tuple:
{
let out = '(';
let first = true;
for (const elementType of type.elementTypes) {
if (first) first = false;
else out += ', ';
out += describeType(elementType);
}
return out + ')';
}
case TypeKind.TupleIndex:
return describeType(type.tupleType) + '.' + type.index;
case TypeKind.Field: case TypeKind.Field:
{ {
let out = '{ ' + type.name + ': ' + describeType(type.type); let out = '{ ' + type.name + ': ' + describeType(type.type);
@ -605,7 +569,7 @@ function describeKind(kind: Kind): string {
return `k${kind.id}`; return `k${kind.id}`;
case KindType.Arrow: case KindType.Arrow:
return describeKind(kind.left) + ' -> ' + describeKind(kind.right); return describeKind(kind.left) + ' -> ' + describeKind(kind.right);
case KindType.Star: case KindType.Type:
return '*'; return '*';
default: default:
assertNever(kind); assertNever(kind);

View file

@ -7,8 +7,6 @@ export enum TypeKind {
UniVar, UniVar,
RigidVar, RigidVar,
Con, Con,
Tuple,
TupleIndex,
App, App,
Nominal, Nominal,
Field, Field,
@ -311,84 +309,20 @@ export class TCon extends TypeBase {
} }
export class TTupleIndex extends TypeBase { export function buildTupleType(types: Type[]): Type {
let out: Type = new TNil();
public readonly kind = TypeKind.TupleIndex; types.forEach((type, i) => {
out = new TField(i, new TPresent(type), out);
public constructor( });
public tupleType: Type, return out;
public index: number,
public node: Syntax | null = null,
) {
super();
}
public getTypeVars(): Iterable<TVar> {
return this.tupleType.getTypeVars();
}
public substitute(sub: TVSub): Type {
const newTupleType = this.tupleType.substitute(sub);
if (newTupleType === this.tupleType) {
return this;
}
return new TTupleIndex(newTupleType, this.index);
}
public shallowClone(): TTupleIndex {
return new TTupleIndex(
this.tupleType,
this.index,
);
}
public [toStringTag](_depth: number, options: InspectOptions, inspect: InspectFn): string {
return inspect(this.tupleType, options) + '.' + this.index;
}
} }
export class TTuple extends TypeBase { export function buildTupleTypeWithLoc(elements: Array<[Syntax, Type]>, node: Syntax) {
let out: Type = new TNil(node);
public readonly kind = TypeKind.Tuple; elements.forEach(([el, type], i) => {
out = new TField(i, new TPresent(type, el), out);
public constructor( });
public elementTypes: Type[], return out;
public node: Syntax | null = null,
) {
super();
}
public *getTypeVars(): Iterable<TVar> {
for (const elementType of this.elementTypes) {
yield* elementType.getTypeVars();
}
}
public shallowClone(): TTuple {
return new TTuple(
this.elementTypes,
this.node,
);
}
public substitute(sub: TVSub): Type {
let changed = false;
const newElementTypes = [];
for (const elementType of this.elementTypes) {
const newElementType = elementType.substitute(sub);
if (newElementType !== elementType) {
changed = true;
}
newElementTypes.push(newElementType);
}
return changed ? new TTuple(newElementTypes, this.node) : this;
}
public [toStringTag](_depth: number, options: InspectOptions, inspect: InspectFn) {
return this.elementTypes.map(t => inspect(t, options)).join(' × ');
}
} }
export class TField extends TypeBase { export class TField extends TypeBase {
@ -396,7 +330,7 @@ export class TField extends TypeBase {
public readonly kind = TypeKind.Field; public readonly kind = TypeKind.Field;
public constructor( public constructor(
public name: string, public name: string | number,
public type: Type, public type: Type,
public restType: Type, public restType: Type,
public node: Syntax | null = null, public node: Syntax | null = null,
@ -426,7 +360,7 @@ export class TField extends TypeBase {
} }
public static sort(type: Type): Type { public static sort(type: Type): Type {
const fields = new Map<string, TField>(); const fields = new Map<string | number, TField>();
while (type.kind === TypeKind.Field) { while (type.kind === TypeKind.Field) {
fields.set(type.name, type); fields.set(type.name, type);
type = type.restType; type = type.restType;
@ -518,13 +452,11 @@ export type Type
| TArrow | TArrow
| TRigidVar | TRigidVar
| TUniVar | TUniVar
| TTuple
| TApp | TApp
| TField | TField
| TNil | TNil
| TPresent | TPresent
| TAbsent | TAbsent
| TTupleIndex
export type TVar export type TVar
= TUniVar = TUniVar
@ -557,23 +489,9 @@ export function typesEqual(a: Type, b: Type): boolean {
case TypeKind.Arrow: case TypeKind.Arrow:
assert(b.kind === TypeKind.Arrow); assert(b.kind === TypeKind.Arrow);
return typesEqual(a.paramType, b.paramType) && typesEqual(a.returnType, b.returnType); return typesEqual(a.paramType, b.paramType) && typesEqual(a.returnType, b.returnType);
case TypeKind.Tuple:
assert(b.kind === TypeKind.Tuple);
if (a.elementTypes.length !== b.elementTypes.length) {
return false;
}
for (let i = 0; i < a.elementTypes.length; i++) {
if (!typesEqual(a.elementTypes[i], b.elementTypes[i])) {
return false;
}
}
return true;
case TypeKind.Present: case TypeKind.Present:
assert(b.kind === TypeKind.Present); assert(b.kind === TypeKind.Present);
return typesEqual(a.type, b.type); return typesEqual(a.type, b.type);
case TypeKind.TupleIndex:
assert(b.kind === TypeKind.TupleIndex);
return a.index === b.index && typesEqual(a.tupleType, b.tupleType);
default: default:
assertNever(a); assertNever(a);
} }