bolt/src/checker.ts

2265 lines
58 KiB
TypeScript
Raw Normal View History

import {
EnumDeclaration,
EnumDeclarationElement,
Expression,
2022-09-05 19:38:55 +02:00
LetDeclaration,
Pattern,
2022-09-07 12:45:38 +02:00
Scope,
SourceFile,
StructDeclaration,
Symkind,
Syntax,
SyntaxKind,
TypeExpression
} from "./cst";
import {
describeType,
ArityMismatchDiagnostic,
BindingNotFoudDiagnostic,
Diagnostics,
FieldDoesNotExistDiagnostic,
FieldMissingDiagnostic,
UnificationFailedDiagnostic,
KindMismatchDiagnostic
} from "./diagnostics";
import { assert, isEmpty } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
// Numeric limit related to type errors; its use is not visible in this part of the
// file — presumably it caps how many type errors get reported (TODO confirm).
const MAX_TYPE_ERROR_COUNT = 5;
// The two kinds of syntax node used as vertices of the reference graph.
type NodeWithBindings = SourceFile | LetDeclaration;
// Graph over `NodeWithBindings` vertices whose edges carry a boolean label.
type ReferenceGraph = LabeledGraph<NodeWithBindings, boolean>;
/**
 * Discriminant stored in the `kind` field of every type representation.
 *
 * Members keep their implicit numeric values (`Arrow` = 0 up to `Variant` = 8),
 * so the declaration order must not change.
 */
export enum TypeKind {
  Arrow,
  Var,
  Con,
  Any,
  Tuple,
  Labeled,
  Record,
  App,
  Variant,
}
/**
 * Common base class of every type representation.
 *
 * A fresh instance has `next` pointing at itself; `join` links two instances
 * together through that field.
 */
abstract class TypeBase {

  /** Discriminant identifying the concrete type class. */
  public abstract readonly kind: TypeKind;

  /** Link to another type; initially the instance itself. */
  public next: Type = this as any;

  /**
   * @param node Syntax node associated with this type, or `null` when there is none.
   */
  public constructor(
    public node: Syntax | null = null
  ) {
  }

  /**
   * Makes `a.next` point at `b` and moves the previous value of `a.next`
   * into `b.next`.
   */
  public static join(a: Type, b: Type): void {
    // The right-hand side is evaluated first, so `a.next` is still the old link.
    [a.next, b.next] = [b, a.next];
  }

  /** Enumerates the type variables that occur in this type. */
  public abstract getTypeVars(): Iterable<TVar>;

  /** Creates a new instance of the same class that shares this type's children. */
  public abstract shallowClone(): Type;

  /** Applies `sub` to this type. */
  public abstract substitute(sub: TVSub): Type;

  /**
   * Tells whether a type variable with the same `id` as `tv` occurs in this type.
   */
  public hasTypeVar(tv: TVar): boolean {
    for (const candidate of this.getTypeVars()) {
      if (candidate.id === tv.id) {
        return true;
      }
    }
    return false;
  }

}
/** A type variable, identified by its numeric `id`. */
class TVar extends TypeBase {

  public readonly kind = TypeKind.Var;

  /**
   * @param id Identifier used to compare type variables.
   * @param node Syntax node associated with this variable, if any.
   */
  public constructor(
    public id: number,
    public node: Syntax | null = null,
  ) {
    super();
  }

  /** A type variable contains exactly one type variable: itself. */
  public *getTypeVars(): Iterable<TVar> {
    yield this;
  }

  /** Creates another `TVar` with the same `id` and `node`. */
  public shallowClone(): TVar {
    return new TVar(this.id, this.node);
  }

  /**
   * Looks this variable up in `sub`. An unmapped variable is returned as is;
   * otherwise the mapped type is itself substituted before being returned.
   */
  public substitute(sub: TVSub): Type {
    const replacement = sub.get(this);
    if (replacement === undefined) {
      return this;
    }
    return replacement.substitute(sub);
  }

}
/** A function type: a list of parameter types and a return type. */
export class TArrow extends TypeBase {

  public readonly kind = TypeKind.Arrow;

  /**
   * @param paramTypes Types of the parameters, in order.
   * @param returnType Type of the result.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public paramTypes: Type[],
    public returnType: Type,
    public node: Syntax | null = null,
  ) {
    super();
  }

  /** Yields the type variables of each parameter type, then those of the return type. */
  public *getTypeVars(): Iterable<TVar> {
    for (const paramType of this.paramTypes) {
      yield* paramType.getTypeVars();
    }
    yield* this.returnType.getTypeVars();
  }

  /** Creates a new arrow that shares the parameter array and return type. */
  public shallowClone(): TArrow {
    return new TArrow(this.paramTypes, this.returnType, this.node);
  }

  /**
   * Substitutes the parameter types and then the return type.
   * Returns `this` when no component was changed by the substitution.
   */
  public substitute(sub: TVSub): Type {
    const substitutedParams = this.paramTypes.map(paramType => paramType.substitute(sub));
    const substitutedReturn = this.returnType.substitute(sub);
    const untouched = substitutedReturn === this.returnType
      && substitutedParams.every((paramType, i) => paramType === this.paramTypes[i]);
    if (untouched) {
      return this;
    }
    return new TArrow(substitutedParams, substitutedReturn, this.node);
  }

}
/** A type constructor with a numeric `id`, argument types and a display name. */
export class TCon extends TypeBase {

  public readonly kind = TypeKind.Con;

  /**
   * @param id Numeric identifier of the constructor.
   * @param argTypes Argument types of the constructor.
   * @param displayName Human-readable name of the constructor.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public id: number,
    public argTypes: Type[],
    public displayName: string,
    public node: Syntax | null = null,
  ) {
    super(node);
  }

  /** Yields the type variables of every argument type, in order. */
  public *getTypeVars(): Iterable<TVar> {
    for (const argType of this.argTypes) {
      yield* argType.getTypeVars();
    }
  }

  /** Creates a new constructor type that shares the argument array. */
  public shallowClone(): TCon {
    return new TCon(this.id, this.argTypes, this.displayName, this.node);
  }

  /**
   * Substitutes every argument type. Returns `this` when none of them changed.
   */
  public substitute(sub: TVSub): Type {
    const substitutedArgs = this.argTypes.map(argType => argType.substitute(sub));
    const untouched = substitutedArgs.every((argType, i) => argType === this.argTypes[i]);
    if (untouched) {
      return this;
    }
    return new TCon(this.id, substitutedArgs, this.displayName, this.node);
  }

}
/** A tuple type made of an ordered list of element types. */
class TTuple extends TypeBase {

  public readonly kind = TypeKind.Tuple;

  /**
   * @param elementTypes Types of the tuple elements, in order.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public elementTypes: Type[],
    public node: Syntax | null = null,
  ) {
    super(node);
  }

  /** Yields the type variables of every element type, in order. */
  public *getTypeVars(): Iterable<TVar> {
    for (const elementType of this.elementTypes) {
      yield* elementType.getTypeVars();
    }
  }

  /** Creates a new tuple type that shares the element array. */
  public shallowClone(): TTuple {
    return new TTuple(this.elementTypes, this.node);
  }

  /**
   * Substitutes every element type. Returns `this` when none of them changed.
   */
  public substitute(sub: TVSub): Type {
    const substitutedElements = this.elementTypes.map(elementType => elementType.substitute(sub));
    const untouched = substitutedElements.every((elementType, i) => elementType === this.elementTypes[i]);
    if (untouched) {
      return this;
    }
    return new TTuple(substitutedElements, this.node);
  }

}
/**
 * A type carrying a single labeled field: the field `name` and its `type`.
 *
 * Instances can be chained through `parent`; `find` follows that chain to its
 * root.
 */
export class TLabeled extends TypeBase {

  public readonly kind = TypeKind.Labeled;

  /** Optional map from field name to field type; not populated by this class. */
  public fields?: Map<string, Type>;

  /** Next element on the way to the root, or `null` when this is a root. */
  public parent: TLabeled | null = null;

  /**
   * @param name Name of the field.
   * @param type Type of the field.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public name: string,
    public type: Type,
    public node: Syntax | null = null,
  ) {
    super(node);
  }

  /**
   * Follows the `parent` links up to the root and returns it.
   *
   * As a shortcut for later calls, `parent` is redirected straight to the
   * root. A root must keep `parent === null`: pointing it at itself would
   * create a cycle and make every later `find` loop forever.
   */
  public find(): TLabeled {
    let curr: TLabeled = this;
    while (curr.parent !== null) {
      curr = curr.parent;
    }
    if (curr !== this) {
      this.parent = curr;
    }
    return curr;
  }

  /** The type variables of a labeled type are those of the field type. */
  public getTypeVars(): Iterable<TVar> {
    return this.type.getTypeVars();
  }

  /** Creates a new labeled type with the same name, field type and node. */
  public shallowClone(): TLabeled {
    return new TLabeled(
      this.name,
      this.type,
      this.node,
    );
  }

  /**
   * Substitutes the field type. Returns `this` when the field type is unchanged.
   */
  public substitute(sub: TVSub): Type {
    const newType = this.type.substitute(sub);
    return newType !== this.type ? new TLabeled(this.name, newType, this.node) : this;
  }

}
/**
 * A record type: the declaration it belongs to, its type-variable arguments
 * and a map from field name to field type.
 */
export class TRecord extends TypeBase {

  public readonly kind = TypeKind.Record;

  /**
   * @param decl Declaration this record type refers to.
   * @param kindArgs Type variables acting as the arguments of the declaration.
   * @param fields Map from field name to field type.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public decl: Syntax,
    public kindArgs: TVar[],
    public fields: Map<string, Type>,
    public node: Syntax | null = null,
  ) {
    super(node);
  }

  /** Yields the type variables of every field type (`kindArgs` are not included). */
  public *getTypeVars(): Iterable<TVar> {
    for (const type of this.fields.values()) {
      yield* type.getTypeVars();
    }
  }

  /** Creates a new record type that shares `kindArgs` and the field map. */
  public shallowClone(): TRecord {
    return new TRecord(
      this.decl,
      this.kindArgs,
      this.fields,
      this.node
    );
  }

  /**
   * Substitutes `kindArgs` and every field type.
   *
   * Each substituted kind argument is asserted to still be a type variable.
   * Returns `this` when nothing changed.
   */
  public substitute(sub: TVSub): Type {
    let changed = false;
    const newTypeVars = [];
    for (const typeVar of this.kindArgs) {
      const newTypeVar = typeVar.substitute(sub);
      assert(newTypeVar.kind === TypeKind.Var);
      if (newTypeVar !== typeVar) {
        changed = true;
      }
      newTypeVars.push(newTypeVar);
    }
    const newFields = new Map();
    for (const [key, type] of this.fields) {
      const newType = type.substitute(sub);
      if (newType !== type) {
        changed = true;
      }
      newFields.set(key, newType);
    }
    return changed ? new TRecord(this.decl, newTypeVars, newFields, this.node) : this;
  }

}
/** A binary type application node with a `left` and a `right` type. */
export class TApp extends TypeBase {

  public readonly kind = TypeKind.App;

  /**
   * @param left Left operand of the application.
   * @param right Right operand of the application.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public left: Type,
    public right: Type,
    public node: Syntax | null = null
  ) {
    super(node);
  }

  /**
   * Folds `types` from left to right into nested applications:
   * `[a, b, c]` becomes `TApp(TApp(a, b), c)`. A single-element list yields
   * that element unchanged.
   */
  public static build(types: Type[]): Type {
    return types
      .slice(1)
      .reduce<Type>((applied, operand) => new TApp(applied, operand), types[0]);
  }

  /** Yields the type variables of `left`, then those of `right`. */
  public *getTypeVars(): Iterable<TVar> {
    for (const side of [this.left, this.right]) {
      yield* side.getTypeVars();
    }
  }

  /** Creates a new application node that shares both operands. */
  public shallowClone() {
    return new TApp(this.left, this.right, this.node);
  }

  /**
   * Substitutes `left` and then `right`. Returns `this` when neither changed.
   */
  public substitute(sub: TVSub): Type {
    const substitutedLeft = this.left.substitute(sub);
    const substitutedRight = this.right.substitute(sub);
    if (substitutedLeft === this.left && substitutedRight === this.right) {
      return this;
    }
    return new TApp(substitutedLeft, substitutedRight, this.node);
  }

}
/**
 * A variant type: the declaration it belongs to, its kind arguments and the
 * types of its elements.
 */
export class TVariant extends TypeBase {

  public readonly kind = TypeKind.Variant;

  /**
   * @param decl Declaration this variant type refers to.
   * @param kindArgs Arguments of the declaration.
   * @param elementTypes Types of the elements of the variant.
   * @param node Syntax node associated with this type, if any.
   */
  public constructor(
    public decl: Syntax,
    public kindArgs: Type[],
    public elementTypes: Type[],
    public node: Syntax | null = null,
  ) {
    super(node);
  }

  /** Yields the type variables of every element type (`kindArgs` are not included). */
  public *getTypeVars(): Iterable<TVar> {
    for (const elementType of this.elementTypes) {
      yield* elementType.getTypeVars();
    }
  }

  /** Creates a new variant type that shares both arrays. */
  public shallowClone(): Type {
    return new TVariant(this.decl, this.kindArgs, this.elementTypes, this.node);
  }

  /**
   * Substitutes the kind arguments and then the element types.
   *
   * Each substituted kind argument is asserted to be a type variable, in the
   * order the arguments are visited. Returns `this` when nothing changed.
   */
  public substitute(sub: TVSub): Type {
    let dirty = false;

    const substitutedKindArgs = [];
    for (const kindArg of this.kindArgs) {
      const substituted = kindArg.substitute(sub);
      assert(substituted.kind === TypeKind.Var);
      dirty = dirty || substituted !== kindArg;
      substitutedKindArgs.push(substituted);
    }

    const substitutedElements = [];
    for (const elementType of this.elementTypes) {
      const substituted = elementType.substitute(sub);
      dirty = dirty || substituted !== elementType;
      substitutedElements.push(substituted);
    }

    if (!dirty) {
      return this;
    }
    return new TVariant(this.decl, substitutedKindArgs, substitutedElements, this.node);
  }

}
/** Union of all concrete type representations. */
export type Type
  = TCon
  | TArrow
  | TVar
  | TTuple
  | TLabeled
  | TRecord
  | TApp
  | TVariant
// The type representations that carry `decl` and `kindArgs`: records and variants.
type KindedType
= TRecord
| TVariant
/**
 * Type guard that accepts exactly the variant and record types.
 */
function isKindedType(type: Type): type is KindedType {
  switch (type.kind) {
    case TypeKind.Variant:
    case TypeKind.Record:
      return true;
    default:
      return false;
  }
}
/**
 * Discriminant stored in the `type` field of every kind representation.
 */
export const enum KindType {
  Star = 0,
  Arrow = 1,
  Var = 2,
}
/**
 * A substitution from kind variables to kinds, keyed by the variable's `id`.
 */
class KVSub {

  private kindsById = new Map<number, Kind>();

  /** Binds `kv` to `kind`, replacing any earlier binding. */
  public set(kv: KVar, kind: Kind): void {
    this.kindsById.set(kv.id, kind);
  }

  /** Returns the kind bound to `kv`, or `undefined` when it is unbound. */
  public get(kv: KVar): Kind | undefined {
    return this.kindsById.get(kv.id);
  }

  /** Tells whether `kv` has a binding. */
  public has(kv: KVar): boolean {
    return this.kindsById.has(kv.id);
  }

  /** Iterates over all bound kinds. */
  public values(): Iterable<Kind> {
    return this.kindsById.values();
  }

}
// Common base class of every kind representation.
abstract class KindBase {
// Discriminant identifying the concrete kind class.
public abstract readonly type: KindType;
// Applies the kind-variable substitution `sub` to this kind.
public abstract substitute(sub: KVSub): Kind;
}
/** A kind variable, identified by its numeric `id`. */
class KVar extends KindBase {

  public readonly type = KindType.Var;

  /**
   * @param id Identifier used to look the variable up in a substitution.
   */
  public constructor(
    public id: number,
  ) {
    super();
  }

  /**
   * Looks this variable up in `sub`. An unbound variable is returned as is;
   * otherwise the bound kind is itself substituted before being returned.
   */
  public substitute(sub: KVSub): Kind {
    const bound = sub.get(this);
    if (bound === undefined) {
      return this;
    }
    return bound.substitute(sub);
  }

}
// The kind discriminated by `KindType.Star`; it has no components.
class KStar extends KindBase {
public readonly type = KindType.Star;
// There is nothing to substitute, so the instance itself is returned.
public substitute(_sub: KVSub): Kind {
return this;
}
}
/** An arrow kind with a `left` and a `right` component. */
class KArrow extends KindBase {

  public readonly type = KindType.Arrow;

  /**
   * @param left Kind on the left of the arrow.
   * @param right Kind on the right of the arrow.
   */
  public constructor(
    public left: Kind,
    public right: Kind,
  ) {
    super();
  }

  /**
   * Substitutes `left` and then `right`, always returning a new arrow.
   */
  public substitute(sub: KVSub): Kind {
    const substitutedLeft = this.left.substitute(sub);
    const substitutedRight = this.right.substitute(sub);
    return new KArrow(substitutedLeft, substitutedRight);
  }

}
// Union of all concrete kind representations.
export type Kind
= KStar
| KArrow
| KVar
/**
 * A set of type variables in which two variables are the same element when
 * they have the same `id`.
 */
class TVSet {

  private varsById = new Map<number, TVar>();

  /** Adds `tv`, replacing any stored variable with the same `id`. */
  public add(tv: TVar): void {
    this.varsById.set(tv.id, tv);
  }

  /** Tells whether a variable with the same `id` as `tv` is stored. */
  public has(tv: TVar): boolean {
    return this.varsById.has(tv.id);
  }

  /** Tells whether at least one type variable of `type` is in this set. */
  public intersectsType(type: Type): boolean {
    for (const candidate of type.getTypeVars()) {
      if (this.varsById.has(candidate.id)) {
        return true;
      }
    }
    return false;
  }

  /** Removes the variable with the same `id` as `tv`, if present. */
  public delete(tv: TVar): void {
    this.varsById.delete(tv.id);
  }

  /** Iterates over the stored variables. */
  public [Symbol.iterator](): Iterator<TVar> {
    return this.varsById.values();
  }

}
/**
 * A substitution from type variables to types, keyed by the variable's `id`.
 */
class TVSub {

  private typesById = new Map<number, Type>();

  /** Maps `tv` to `type`, replacing any earlier mapping. */
  public set(tv: TVar, type: Type): void {
    this.typesById.set(tv.id, type);
  }

  /** Returns the type `tv` is mapped to, or `undefined` when it is unmapped. */
  public get(tv: TVar): Type | undefined {
    return this.typesById.get(tv.id);
  }

  /** Tells whether `tv` has a mapping. */
  public has(tv: TVar): boolean {
    return this.typesById.has(tv.id);
  }

  /** Removes the mapping of `tv`, if there is one. */
  public delete(tv: TVar): void {
    this.typesById.delete(tv.id);
  }

  /** Iterates over all mapped types. */
  public values(): Iterable<Type> {
    return this.typesById.values();
  }

}
/**
 * Discriminant stored in the `kind` field of every constraint.
 * Members keep their implicit numeric values (`Equal` = 0, `Many` = 1, `Shaped` = 2).
 */
const enum ConstraintKind {
  Equal,
  Many,
  Shaped,
}
/**
 * Common base class of every constraint.
 *
 * A constraint may remember the constraint it was instantiated from in
 * `prevInstantiation`, which forms a chain that `getNodes` walks.
 */
abstract class ConstraintBase {

  /** Applies `sub` to the types mentioned by this constraint. */
  public abstract substitute(sub: TVSub): Constraint;

  /**
   * @param node Syntax node associated with this constraint, or `null`.
   */
  public constructor(
    public node: Syntax | null = null
  ) {
  }

  /** Constraint this one was derived from, or `null` when there is none. */
  public prevInstantiation: Constraint | null = null;

  /**
   * Yields the non-null `node` of this constraint and of every constraint
   * reachable through `prevInstantiation`, starting with this one.
   */
  public *getNodes(): Iterable<Syntax> {
    for (
      let link: Constraint | null = this as any;
      link !== null;
      link = link.prevInstantiation
    ) {
      if (link.node !== null) {
        yield link.node;
      }
    }
  }

}
/**
 * A constraint relating a labeled type (`recordType`) to another type.
 * It is constructed without a syntax node.
 */
class CShaped extends ConstraintBase {

  public readonly kind = ConstraintKind.Shaped;

  /**
   * @param recordType Labeled type taking part in the constraint.
   * @param type The other type taking part in the constraint.
   */
  public constructor(
    public recordType: TLabeled,
    public type: Type,
  ) {
    super();
  }

  /**
   * Substitutes both types and returns a new constraint.
   * The substituted record type is assumed to still be a `TLabeled`.
   */
  public substitute(sub: TVSub): Constraint {
    return new CShaped(
      this.recordType.substitute(sub) as TLabeled,
      this.type.substitute(sub),
    );
  }

}
/** A constraint stating that the types `left` and `right` must be equal. */
class CEqual extends ConstraintBase {

  public readonly kind = ConstraintKind.Equal;

  /**
   * @param left First type of the equation.
   * @param right Second type of the equation.
   * @param node Syntax node the constraint originates from.
   */
  public constructor(
    public left: Type,
    public right: Type,
    public node: Syntax,
  ) {
    super();
  }

  /** Substitutes `left` and then `right`, returning a new constraint on the same node. */
  public substitute(sub: TVSub): Constraint {
    const substitutedLeft = this.left.substitute(sub);
    const substitutedRight = this.right.substitute(sub);
    return new CEqual(substitutedLeft, substitutedRight, this.node);
  }

  /** Writes a `left ~ right` description of this constraint to standard error. */
  public dump(): void {
    const lhs = describeType(this.left);
    const rhs = describeType(this.right);
    console.error(`${lhs} ~ ${rhs}`);
  }

}
/** A constraint that groups a list of other constraints. */
class CMany extends ConstraintBase {

  public readonly kind = ConstraintKind.Many;

  /**
   * @param elements The grouped constraints.
   */
  public constructor(
    public elements: Constraint[]
  ) {
    super();
  }

  /** Substitutes every element, in order, and returns a new group. */
  public substitute(sub: TVSub): Constraint {
    return new CMany(this.elements.map(element => element.substitute(sub)));
  }

}
/** Union of all concrete constraint classes. */
type Constraint
  = CEqual
  | CMany
  | CShaped
// A collection of constraints; it is a plain `Array` subclass without extra members.
class ConstraintSet extends Array<Constraint> {
}
// Common base class of type schemes; it declares no members of its own.
abstract class SchemeBase {
}
/**
 * A type scheme: a type together with the type variables and the constraints
 * that are replaced or copied when the scheme is instantiated.
 */
class Forall extends SchemeBase {

  /** Type variables that receive fresh replacements on instantiation. */
  public typeVars: Iterable<TVar>;

  /** Constraints that come with the scheme. */
  public constraints: Iterable<Constraint>;

  /** The type described by the scheme. */
  public type: Type;

  public constructor(
    typeVars: Iterable<TVar>,
    constraints: Iterable<Constraint>,
    type: Type,
  ) {
    super();
    this.typeVars = typeVars;
    this.constraints = constraints;
    this.type = type;
  }

}
// A type scheme; `Forall` is the only variant declared here.
type Scheme
= Forall
/**
 * A lexically nested environment mapping names to type schemes.
 */
export class TypeEnv {

  private mapping = new Map<string, Scheme>();

  /**
   * @param parent Enclosing environment consulted when a name is not found here.
   */
  public constructor(public parent: TypeEnv | null = null) {
  }

  /** Binds `name` to `scheme` in this environment, replacing an earlier binding. */
  public add(name: string, scheme: Scheme): void {
    this.mapping.set(name, scheme);
  }

  /**
   * Finds the scheme bound to `name`, searching this environment first and
   * then each parent in turn.
   *
   * @returns The innermost binding, or `null` when no environment binds `name`.
   */
  public lookup(name: string): Scheme | null {
    let curr: TypeEnv | null = this;
    do {
      const scheme = curr.mapping.get(name);
      if (scheme !== undefined) {
        return scheme;
      }
      curr = curr.parent;
    } while (curr !== null);
    return null;
  }

}
/**
 * A nested environment that assigns kinds to names and to type variables.
 */
class KindEnv {

  /** Kinds of named entities, keyed by name. */
  private mapping1 = new Map<string, Kind>();

  /** Kinds of type variables, keyed by the variable's `id`. */
  private mapping2 = new Map<number, Kind>();

  /**
   * @param parent Enclosing environment consulted by `lookupNamed`.
   */
  public constructor(public parent: KindEnv | null = null) {
  }

  /** Assigns `kind` to `name`; asserts that this environment has no kind for `name` yet. */
  public setNamed(name: string, kind: Kind): void {
    const alreadySet = this.mapping1.has(name);
    assert(!alreadySet);
    this.mapping1.set(name, kind);
  }

  /** Assigns `kind` to `tv`; asserts that this environment has no kind for `tv` yet. */
  public setVar(tv: TVar, kind: Kind): void {
    const alreadySet = this.mapping2.has(tv.id);
    assert(!alreadySet);
    this.mapping2.set(tv.id, kind);
  }

  /**
   * Finds the kind of `name`, searching this environment and then its parents.
   *
   * @returns The innermost kind, or `null` when no environment knows `name`.
   */
  public lookupNamed(name: string): Kind | null {
    for (let env: KindEnv | null = this; env !== null; env = env.parent) {
      const found = env.mapping1.get(name);
      if (found !== undefined) {
        return found;
      }
    }
    return null;
  }

  /**
   * Returns the kind assigned to `tv` in this environment only; parents are
   * not consulted.
   */
  public lookupVar(tv: TVar): Kind | null {
    const found = this.mapping2.get(tv.id);
    return found ?? null;
  }

}
// State carried by one entry of the checker's context stack.
export interface InferContext {
// Type variables created while this context is the active one.
typeVars: TVSet;
// Environment used to resolve and register bindings.
env: TypeEnv;
// Constraints collected while this context is the active one.
constraints: ConstraintSet;
// Type that `return` statements are constrained against, or `null` when none is set.
returnType: Type | null;
}
/**
 * Tells whether a let-declaration should be treated like a function: its
 * pattern is a plain bind pattern and it either has parameters or a block body.
 */
function isFunctionDeclarationLike(node: LetDeclaration): boolean {
  if (node.pattern.kind !== SyntaxKind.BindPattern) {
    return false;
  }
  if (node.params.length > 0) {
    return true;
  }
  return node.body !== null && node.body.kind === SyntaxKind.BlockBody;
}
export class Checker {
// Counters that hand out fresh ids for type variables, kind variables and type constructors.
private nextTypeVarId = 0;
private nextKindVarId = 0;
private nextConTypeId = 0;
// Built-in constructor types. Each initializer consumes the next constructor id,
// so the order of these three fields determines their ids.
private stringType = new TCon(this.nextConTypeId++, [], 'String');
private intType = new TCon(this.nextConTypeId++, [], 'Int');
private boolType = new TCon(this.nextConTypeId++, [], 'Bool');
// Stack of inference contexts; the last element is the active one.
private contexts: InferContext[] = [];
// Type-variable substitution; its use is not visible in this part of the file.
private solution = new TVSub();
// Kind-variable substitution built up by `unifyKind`.
private kindSolution = new KVSub();
// `diagnostics` receives every diagnostic the checker reports.
public constructor(
private diagnostics: Diagnostics
) {
}
// Returns the built-in `Int` constructor type (the shared instance, not a copy).
public getIntType(): Type {
return this.intType;
}
// Returns the built-in `String` constructor type (the shared instance, not a copy).
public getStringType(): Type {
return this.stringType;
}
// Returns the built-in `Bool` constructor type (the shared instance, not a copy).
public getBoolType(): Type {
return this.boolType;
}
/**
 * Creates a type variable with a fresh id and registers it in the type
 * variable set of the active context.
 */
private createTypeVar(): TVar {
  const fresh = new TVar(this.nextTypeVarId++);
  const active = this.contexts[this.contexts.length - 1];
  active.typeVars.add(fresh);
  return fresh;
}
/** Appends `constraint` to the constraint set of the active context. */
private addConstraint(constraint: Constraint): void {
  const active = this.contexts[this.contexts.length - 1];
  active.constraints.push(constraint);
}
// Makes `context` the active inference context by pushing it onto the stack.
private pushContext(context: InferContext) {
this.contexts.push(context);
}
/**
 * Removes the active context from the stack, asserting first that it is
 * the given `context`.
 */
private popContext(context: InferContext) {
  const active = this.contexts[this.contexts.length - 1];
  assert(active === context);
  this.contexts.pop();
}
/**
 * Resolves `name` in the environment of the active context.
 *
 * @returns The scheme bound to `name`, or `null` when it is unbound.
 */
private lookup(name: string): Scheme | null {
  const { env } = this.contexts[this.contexts.length - 1];
  return env.lookup(name);
}
/**
 * Returns the return type of the active context, asserting that there is an
 * active context and that it has a return type.
 */
private getReturnType(): Type {
  const active = this.contexts[this.contexts.length - 1];
  assert(active && active.returnType !== null);
  return active.returnType;
}
/**
 * Builds a substitution that maps every type variable of `scheme` to a
 * freshly created type variable.
 */
private createSubstitution(scheme: Scheme): TVSub {
  const sub = new TVSub();
  for (const tv of scheme.typeVars) {
    sub.set(tv, this.createTypeVar());
  }
  return sub;
}
/**
 * Instantiates `scheme`.
 *
 * Every constraint of the scheme is substituted, tagged with `node`, linked
 * back to the constraint it was derived from and added to the active
 * context. The scheme's type is then substituted and returned.
 *
 * @param scheme Scheme to instantiate.
 * @param node Syntax node recorded on the instantiated constraints.
 * @param sub Substitution to apply; defaults to one that maps each type
 *            variable of the scheme to a fresh one.
 */
private instantiate(scheme: Scheme, node: Syntax | null, sub = this.createSubstitution(scheme)): Type {
  for (const original of scheme.constraints) {
    const instantiated = original.substitute(sub);
    instantiated.node = node;
    instantiated.prevInstantiation = original;
    this.addConstraint(instantiated);
  }
  const instantiatedType = scheme.type.substitute(sub);
  return instantiatedType;
}
/** Binds `name` to `scheme` in the environment of the active context. */
private addBinding(name: string, scheme: Scheme): void {
  const { env } = this.contexts[this.contexts.length - 1];
  env.add(name, scheme);
}
/**
 * Computes the kind of a type expression.
 *
 * @param node Type expression to inspect.
 * @param env Kind environment used to resolve names.
 * @returns The kind of `node`. When a name cannot be resolved a diagnostic
 *          is reported and a fresh kind variable is returned instead.
 * @throws Error when `node` is a kind of type expression that is not handled.
 */
private inferKindFromTypeExpression(node: TypeExpression, env: KindEnv): Kind {
  switch (node.kind) {
    case SyntaxKind.VarTypeExpression:
    case SyntaxKind.ReferenceTypeExpression:
    {
      const kind = env.lookupNamed(node.name.text);
      if (kind === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
        // Create a filler kind variable that still will be able to catch other errors.
        return this.createKindVar();
      }
      return kind;
    }
    case SyntaxKind.AppTypeExpression:
    {
      // Apply the operator's kind to the kind of each argument in turn.
      let result = this.inferKindFromTypeExpression(node.operator, env);
      for (const arg of node.args) {
        result = this.applyKind(result, this.inferKindFromTypeExpression(arg, env), node);
      }
      return result;
    }
    case SyntaxKind.NestedTypeExpression:
    {
      return this.inferKindFromTypeExpression(node.typeExpr, env);
    }
    case SyntaxKind.ArrowTypeExpression:
    {
      // Every parameter and the result of a function type must be a proper
      // type (kind `*`), and so is the function type itself.
      for (const paramTypeExpr of node.paramTypeExprs) {
        this.unifyKind(this.inferKindFromTypeExpression(paramTypeExpr, env), new KStar(), paramTypeExpr);
      }
      this.unifyKind(this.inferKindFromTypeExpression(node.returnTypeExpr, env), new KStar(), node.returnTypeExpr);
      return new KStar();
    }
    default:
      throw new Error(`Unexpected ${node}`);
  }
}
// Creates a kind variable with a fresh id.
private createKindVar(): KVar {
return new KVar(this.nextKindVarId++);
}
/**
 * Computes the kind that results from applying `operator` to `arg`.
 *
 * @param operator Kind of the thing being applied.
 * @param arg Kind of the argument.
 * @param node Syntax node used when reporting kind errors.
 * @returns The resulting kind. When `operator` is `*` a mismatch is reported
 *          and a fresh kind variable is returned.
 */
private applyKind(operator: Kind, arg: Kind, node: Syntax): Kind {

  if (operator.type === KindType.Var) {
    // Force the unknown operator to be an arrow `a1 -> a2` whose argument
    // side matches `arg`; the application then has kind `a2`.
    const argKind = this.createKindVar();
    const resultKind = this.createKindVar();
    this.unifyKind(new KArrow(argKind, resultKind), operator, node);
    this.unifyKind(argKind, arg, node);
    return resultKind;
  }

  if (operator.type === KindType.Arrow) {
    // Unify the argument to the operator's argument kind and return
    // whatever the operator returns.
    this.unifyKind(operator.left, arg, node);
    return operator.right;
  }

  // Only `*` is left, which cannot be applied to anything.
  const expectedLeft = this.createKindVar();
  const expectedRight = this.createKindVar();
  const expected = new KArrow(expectedLeft, expectedRight);
  this.diagnostics.add(new KindMismatchDiagnostic(operator, expected, node));
  // Create a filler kind variable that still will be able to catch other errors.
  return this.createKindVar();
}
/**
 * Registers a fresh kind variable in `env` for every struct and enum
 * declaration reachable from `node`, and for each of their members.
 * Other nodes are ignored.
 */
private forwardDeclareKind(node: Syntax, env: KindEnv): void {

  if (node.kind === SyntaxKind.SourceFile) {
    for (const element of node.elements) {
      this.forwardDeclareKind(element, env);
    }
    return;
  }

  if (node.kind === SyntaxKind.StructDeclaration || node.kind === SyntaxKind.EnumDeclaration) {
    env.setNamed(node.name.text, this.createKindVar());
    if (node.members === null) {
      return;
    }
    for (const member of node.members) {
      env.setNamed(member.name.text, this.createKindVar());
    }
  }

}
/**
 * Infers kinds for the declarations reachable from `node`.
 *
 * Enum declarations are unified with an arrow kind built from their type
 * parameters, and the element types of their tuple members must have kind
 * `*`. The type assertion of a let-declaration must have kind `*` as well.
 * Struct declarations are not processed yet.
 */
private inferKind(node: Syntax, env: KindEnv): void {
  switch (node.kind) {

    case SyntaxKind.SourceFile:
    {
      for (const element of node.elements) {
        this.inferKind(element, env);
      }
      break;
    }

    case SyntaxKind.StructDeclaration:
    {
      // TODO
      break;
    }

    case SyntaxKind.EnumDeclaration:
    {
      const declKind = env.lookupNamed(node.name.text)!;
      const innerEnv = new KindEnv(env);
      // FIXME should I go from right to left or left to right?
      // Parameters are visited last to first, each one wrapping the kind
      // built so far: `p1 -> (p2 -> ... -> *)`.
      let kind: Kind = new KStar();
      let index = node.varExps.length;
      while (index-- > 0) {
        const paramKind = this.createKindVar();
        innerEnv.setNamed(node.varExps[index].text, paramKind);
        kind = new KArrow(paramKind, kind);
      }
      this.unifyKind(declKind, kind, node);
      if (node.members === null) {
        break;
      }
      for (const member of node.members) {
        // TODO
        if (member.kind !== SyntaxKind.EnumDeclarationTupleElement) {
          continue;
        }
        for (const element of member.elements) {
          const elementKind = this.inferKindFromTypeExpression(element, innerEnv);
          this.unifyKind(elementKind, new KStar(), element);
        }
      }
      break;
    }

    case SyntaxKind.LetDeclaration:
    {
      if (node.typeAssert !== null) {
        const typeExpr = node.typeAssert.typeExpression;
        const assertedKind = this.inferKindFromTypeExpression(typeExpr, env);
        this.unifyKind(assertedKind, new KStar(), typeExpr);
      }
      if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
        for (const element of node.body.elements) {
          // TODO fork `env` to support local type declarations
          this.inferKind(element, env);
        }
      }
      break;
    }

  }
}
/**
 * Unifies two kinds, extending `kindSolution` with the bindings needed to
 * make them equal.
 *
 * @param a First kind.
 * @param b Second kind.
 * @param node Syntax node attached to the diagnostic when unification fails.
 * @returns `true` when the kinds were unified, `false` when a
 *          `KindMismatchDiagnostic` was reported.
 */
private unifyKind(a: Kind, b: Kind, node: Syntax): boolean {

  // Follows the bindings of kind variables until an unbound variable or a
  // non-variable kind is reached.
  const find = (kind: Kind): Kind => {
    let curr = kind;
    while (curr.type === KindType.Var && this.kindSolution.has(curr)) {
      curr = this.kindSolution.get(curr)!;
    }
    return curr;
  }

  // Fully applies the current solution; used to render diagnostics.
  const solve = (kind: Kind) => kind.substitute(this.kindSolution);

  a = find(a);
  b = find(b);

  // A variable is trivially equal to itself. Binding it to itself would
  // create a cycle that makes `find` loop forever.
  if (a.type === KindType.Var && b.type === KindType.Var && a.id === b.id) {
    return true;
  }

  if (a.type === KindType.Var) {
    this.kindSolution.set(a, b);
    return true;
  }

  if (b.type === KindType.Var) {
    return this.unifyKind(b, a, node);
  }

  if (a.type === KindType.Star && b.type === KindType.Star) {
    return true;
  }

  if (a.type === KindType.Arrow && b.type === KindType.Arrow) {
    // Both sides must always be unified; a short-circuiting operator would
    // skip the right-hand sides and leave their variables unconstrained.
    const leftUnified = this.unifyKind(a.left, b.left, node);
    const rightUnified = this.unifyKind(a.right, b.right, node);
    return leftUnified && rightUnified;
  }

  this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
  return false;
}
/**
 * Generates typing constraints for a statement or declaration and adds them
 * to the active context.
 *
 * Function-like let-declarations are skipped here, and type, enum and struct
 * declarations produce no constraints.
 *
 * @throws Error when `node` is a kind of syntax that is not handled.
 */
private infer(node: Syntax): void {

  switch (node.kind) {

    case SyntaxKind.SourceFile:
    {
      for (const element of node.elements) {
        this.infer(element);
      }
      break;
    }

    case SyntaxKind.ExpressionStatement:
    {
      this.inferExpression(node.expression);
      break;
    }

    case SyntaxKind.IfStatement:
    {
      for (const cs of node.cases) {
        // A case without a test has nothing to constrain; every other
        // test must be a Bool.
        if (cs.test !== null) {
          this.addConstraint(
            new CEqual(
              this.inferExpression(cs.test),
              this.getBoolType(),
              cs.test
            )
          );
        }
        for (const element of cs.elements) {
          this.infer(element);
        }
      }
      break;
    }

    case SyntaxKind.ReturnStatement:
    {
      // A bare `return` yields the empty tuple.
      let type;
      if (node.expression === null) {
        type = new TTuple([]);
      } else {
        type = this.inferExpression(node.expression);
      }
      this.addConstraint(
        new CEqual(
          this.getReturnType(),
          type,
          node
        )
      );
      break;
    }

    case SyntaxKind.LetDeclaration:
    {
      if (isFunctionDeclarationLike(node)) {
        break;
      }
      let type;
      if (node.pattern.kind === SyntaxKind.WrappedOperator) {
        type = this.createTypeVar();
        this.addBinding(node.pattern.operator.text, new Forall([], [], type));
      } else {
        type = this.inferBindings(node.pattern, [], []);
      }
      if (node.typeAssert !== null) {
        this.addConstraint(
          new CEqual(
            this.inferTypeExpression(node.typeAssert.typeExpression),
            type,
            node
          )
        );
      }
      if (node.body !== null) {
        switch (node.body.kind) {
          case SyntaxKind.ExprBody:
          {
            const type2 = this.inferExpression(node.body.expression);
            this.addConstraint(
              new CEqual(
                type,
                type2,
                node
              )
            );
            break;
          }
          case SyntaxKind.BlockBody:
          {
            // TODO
            assert(false);
          }
        }
      }
      break;
    }

    case SyntaxKind.TypeDeclaration:
    case SyntaxKind.EnumDeclaration:
    case SyntaxKind.StructDeclaration:
      break;

    default:
      throw new Error(`Unexpected ${node.constructor.name}`);

  }

}
/**
 * Builds the type of the enum that owns the member `decl`.
 *
 * One fresh type variable is created per type parameter of the enum and per
 * member other than `decl`. The result applies those kind arguments to a
 * `TVariant` for the enum; the per-member types are computed but are not
 * part of the returned type.
 *
 * @param decl Enum member the type is built for.
 * @param type Type to use for `decl` itself.
 */
private buildVariantType(decl: EnumDeclarationElement, type: Type): Type {
  const owner = decl.parent as EnumDeclaration;

  // One fresh type variable per type parameter of the enum.
  const kindArgs = Array.from(owner.varExps, () => this.createTypeVar());

  // One type per member: the supplied type for `decl`, a fresh type variable otherwise.
  const variantTypes = [];
  for (const member of owner.members ?? []) {
    variantTypes.push(member === decl ? type : this.createTypeVar());
  }

  const variant = new TVariant(owner, [], []);
  return TApp.build([ ...kindArgs, variant ]);
}
/**
 * Infers the type of an expression, adding the constraints it gives rise to
 * to the active context.
 *
 * When a referenced name cannot be resolved a `BindingNotFoudDiagnostic` is
 * reported and a fresh type variable stands in for the missing type.
 *
 * @param node Expression to inspect.
 * @returns The type of `node`, usually still containing type variables.
 * @throws Error when `node` is a kind of expression that is not handled.
 */
public inferExpression(node: Expression): Type {

  switch (node.kind) {

    case SyntaxKind.NestedExpression:
      return this.inferExpression(node.expression);

    case SyntaxKind.ReferenceExpression:
    {
      assert(node.name.modulePath.length === 0);
      const scope = node.getScope();
      const target = scope.lookup(node.name.name.text);
      // A let-declaration flagged `active` is referenced through its own
      // `type` instead of an instantiated scheme.
      if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) {
        return target.type!;
      }
      const scheme = this.lookup(node.name.name.text);
      if (scheme === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name));
        return this.createTypeVar();
      }
      const type = this.instantiate(scheme, node);
      type.node = node;
      return type;
    }

    case SyntaxKind.MemberExpression:
    {
      // Each `.name` step constrains the current type to have a field
      // `name` and continues with the type of that field.
      let type = this.inferExpression(node.expression);
      for (const [_dot, name] of node.path) {
        const newType = this.createTypeVar();
        this.addConstraint(
          new CEqual(
            type,
            new TLabeled(name.text, newType),
            node,
          )
        );
        type = newType;
      }
      return type;
    }

    case SyntaxKind.CallExpression:
    {
      const opType = this.inferExpression(node.func);
      const retType = this.createTypeVar();
      const paramTypes = [];
      for (const arg of node.args) {
        paramTypes.push(this.inferExpression(arg));
      }
      // The callee must be a function from the argument types to `retType`.
      this.addConstraint(
        new CEqual(
          opType,
          new TArrow(paramTypes, retType),
          node
        )
      );
      return retType;
    }

    case SyntaxKind.ConstantExpression:
    {
      let ty;
      switch (node.token.kind) {
        case SyntaxKind.StringLiteral:
          ty = this.getStringType();
          break;
        case SyntaxKind.Integer:
          ty = this.getIntType();
          break;
      }
      // Clone before tagging so the shared built-in type is not mutated.
      ty = ty.shallowClone();
      ty.node = node;
      return ty;
    }

    case SyntaxKind.NamedTupleExpression:
    {
      // TODO Only lookup constructors and skip other bindings
      const scheme = this.lookup(node.name.text);
      if (scheme === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
        return this.createTypeVar();
      }
      const operatorType = this.instantiate(scheme, node.name);
      const argTypes = node.elements.map(el => this.inferExpression(el));
      const retType = this.createTypeVar();
      this.addConstraint(
        new CEqual(
          new TArrow(
            argTypes,
            retType,
            node,
          ),
          operatorType,
          node
        )
      );
      return retType;
    }

    case SyntaxKind.StructExpression:
    {
      const scope = node.getScope();
      const decl = scope.lookup(node.name.text, Symkind.Constructor);
      if (decl === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
        return this.createTypeVar();
      }
      assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement);
      const scheme = decl.scheme;
      // The substitution is created up front so that the fresh variables
      // standing in for the declaration's type variables can be read back.
      const sub = this.createSubstitution(scheme);
      const declType = this.instantiate(scheme, node, sub);
      const argTypes = [];
      for (const typeVar of decl.tvs) {
        const newTypeVar = sub.get(typeVar)!;
        assert(newTypeVar.kind === TypeKind.Var);
        argTypes.push(newTypeVar);
      }
      const fields = new Map();
      for (const member of node.members) {
        switch (member.kind) {
          case SyntaxKind.StructExpressionField:
          {
            fields.set(member.name.text, this.inferExpression(member.expression));
            break;
          }
          case SyntaxKind.PunnedStructExpressionField:
          {
            // A punned field takes the type of the binding with the same name.
            const scheme = this.lookup(member.name.text);
            let fieldType;
            if (scheme === null) {
              this.diagnostics.add(new BindingNotFoudDiagnostic(member.name.text, member.name));
              fieldType = this.createTypeVar();
            } else {
              fieldType = this.instantiate(scheme, member);
            }
            fields.set(member.name.text, fieldType);
            break;
          }
          default:
            throw new Error(`Unexpected ${member}`);
        }
      }
      let type: Type = new TRecord(decl, argTypes, fields, node);
      if (decl.kind === SyntaxKind.EnumDeclarationStructElement) {
        type = this.buildVariantType(decl, type);
      }
      this.addConstraint(
        new CEqual(
          declType,
          type,
          node,
        )
      );
      return type;
    }

    case SyntaxKind.InfixExpression:
    {
      const scheme = this.lookup(node.operator.text);
      if (scheme === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(node.operator.text, node.operator));
        return this.createTypeVar();
      }
      const opType = this.instantiate(scheme, node.operator);
      const retType = this.createTypeVar();
      const leftType = this.inferExpression(node.left);
      const rightType = this.inferExpression(node.right);
      // The operator must be a two-argument function returning `retType`.
      this.addConstraint(
        new CEqual(
          new TArrow([ leftType, rightType ], retType),
          opType,
          node,
        ),
      );
      return retType;
    }

    default:
      throw new Error(`Unexpected ${node.constructor.name}`);

  }

}
/**
 * Converts a type expression into a type.
 *
 * @param node Type expression to convert.
 * @param introduceTypeVars When `true`, a type variable name that is not
 *        bound yet is introduced without reporting a diagnostic. When
 *        `false`, the variable is still introduced but a
 *        `BindingNotFoudDiagnostic` is reported as well.
 * @returns The type denoted by `node`.
 * @throws Error when `node` is a kind of type expression that is not handled.
 */
public inferTypeExpression(node: TypeExpression, introduceTypeVars = false): Type {

  switch (node.kind) {

    case SyntaxKind.ReferenceTypeExpression:
    {
      const scheme = this.lookup(node.name.text);
      if (scheme === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
        return this.createTypeVar();
      }
      const type = this.instantiate(scheme, node.name);
      // FIXME it is not guaranteed that `type` is copied, so the original type might get mutated
      type.node = node;
      return type;
    }

    case SyntaxKind.NestedTypeExpression:
      return this.inferTypeExpression(node.typeExpr, introduceTypeVars);

    case SyntaxKind.VarTypeExpression:
    {
      const scheme = this.lookup(node.name.text);
      if (scheme === null) {
        if (!introduceTypeVars) {
          this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
        }
        // Bind the name to a fresh type variable so that later occurrences
        // of the same name resolve to the same variable.
        const type = this.createTypeVar();
        this.addBinding(node.name.text, new Forall([], [], type));
        return type;
      }
      // The binding of a type variable is expected to be a scheme without
      // type variables or constraints, so its type is used without instantiation.
      assert(isEmpty(scheme.typeVars));
      assert(isEmpty(scheme.constraints));
      return scheme.type;
    }

    case SyntaxKind.AppTypeExpression:
    {
      // The argument types come first and the operator type last.
      return TApp.build([
        ...node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars)),
        this.inferTypeExpression(node.operator, introduceTypeVars),
      ]);
    }

    case SyntaxKind.ArrowTypeExpression:
    {
      const paramTypes = [];
      for (const paramTypeExpr of node.paramTypeExprs) {
        paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars));
      }
      const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars);
      return new TArrow(paramTypes, returnType, node);
    }

    default:
      throw new Error(`Unrecognised ${node}`);

  }

}
/**
 * Introduces the bindings declared by `pattern` and returns the type that a
 * value matched against the pattern must have.
 *
 * @param pattern Pattern to inspect.
 * @param typeVars Type variables stored in the scheme of each bind pattern.
 * @param constraints Constraints stored in the scheme of each bind pattern.
 * @returns The type of the pattern.
 * @throws Error when the pattern, or one of the members of a struct pattern,
 *         is of a kind that is not handled.
 */
public inferBindings(pattern: Pattern, typeVars: TVar[], constraints: Constraint[]): Type {

  switch (pattern.kind) {

    case SyntaxKind.BindPattern:
    {
      const type = this.createTypeVar();
      this.addBinding(pattern.name.text, new Forall(typeVars, constraints, type));
      return type;
    }

    case SyntaxKind.StructPattern:
    {
      const scheme = this.lookup(pattern.name.text);
      let recordType;
      if (scheme === null) {
        this.diagnostics.add(new BindingNotFoudDiagnostic(pattern.name.text, pattern.name));
        recordType = this.createTypeVar();
      } else {
        recordType = this.instantiate(scheme, pattern.name);
      }
      const type = this.createTypeVar();
      for (const member of pattern.members) {
        switch (member.kind) {
          case SyntaxKind.StructPatternField:
          {
            // `name: pattern` — the field has the type of the nested pattern.
            const fieldType = this.inferBindings(member.pattern, typeVars, constraints);
            this.addConstraint(
              new CEqual(
                new TLabeled(member.name.text, fieldType),
                type,
                member
              )
            );
            break;
          }
          case SyntaxKind.PunnedStructPatternField:
          {
            // A punned field binds a variable named after the field.
            const fieldType = this.createTypeVar();
            this.addBinding(member.name.text, new Forall([], [], fieldType));
            this.addConstraint(
              new CEqual(
                new TLabeled(member.name.text, fieldType),
                type,
                member
              )
            );
            break;
          }
          default:
            throw new Error(`Unexpected ${member.constructor.name}`);
        }
      }
      // Finally tie the pattern's type to the type named by the pattern.
      this.addConstraint(
        new CEqual(
          recordType,
          type,
          pattern
        )
      );
      return type;
    }

    default:
      throw new Error(`Unexpected ${pattern.constructor.name}`);

  }

}
/**
 * Walks `node` and records in `graph` every reference made from `source`
 * to another let-declaration or source file.
 *
 * An edge `source -> target` labeled `true` is added for each name that
 * resolves, through the scope of the referencing node, to something other
 * than a parameter. Nested let-declarations are added as vertices and become
 * the `source` for everything inside their body.
 *
 * @param graph  The reference graph being populated.
 * @param node   The syntax node to scan.
 * @param source The declaration (or source file) that owns `node`.
 * @throws Error when a node kind that is not handled here is encountered.
 */
private addReferencesToGraph(graph: ReferenceGraph, node: Syntax, source: LetDeclaration | SourceFile) {

  const addReference = (scope: Scope, name: string) => {
    const target = scope.lookup(name);
    // Unresolved names and parameters do not contribute edges.
    if (target === null || target.kind === SyntaxKind.Param) {
      return;
    }
    assert(target.kind === SyntaxKind.LetDeclaration || target.kind === SyntaxKind.SourceFile);
    graph.addEdge(source, target, true);
  }

  switch (node.kind) {

    case SyntaxKind.ConstantExpression:
      break;

    case SyntaxKind.SourceFile:
    {
      for (const element of node.elements) {
        this.addReferencesToGraph(graph, element, source);
      }
      break;
    }

    case SyntaxKind.ReferenceExpression:
    {
      assert(node.name.modulePath.length === 0);
      addReference(node.getScope(), node.name.name.text);
      break;
    }

    case SyntaxKind.MemberExpression:
    {
      this.addReferencesToGraph(graph, node.expression, source);
      break;
    }

    case SyntaxKind.NamedTupleExpression:
    {
      for (const arg of node.elements) {
        this.addReferencesToGraph(graph, arg, source);
      }
      break;
    }

    case SyntaxKind.StructExpression:
    {
      for (const member of node.members) {
        switch (member.kind) {
          case SyntaxKind.PunnedStructExpressionField:
          {
            // A punned field refers to the variable that shares the field's
            // name, so the member's own name must be looked up (not the name
            // of the enclosing struct expression).
            addReference(node.getScope(), member.name.text);
            break;
          }
          case SyntaxKind.StructExpressionField:
          {
            this.addReferencesToGraph(graph, member.expression, source);
            break;
          }
        }
      }
      break;
    }

    case SyntaxKind.NestedExpression:
    {
      this.addReferencesToGraph(graph, node.expression, source);
      break;
    }

    case SyntaxKind.InfixExpression:
    {
      this.addReferencesToGraph(graph, node.left, source);
      this.addReferencesToGraph(graph, node.right, source);
      break;
    }

    case SyntaxKind.CallExpression:
    {
      this.addReferencesToGraph(graph, node.func, source);
      for (const arg of node.args) {
        this.addReferencesToGraph(graph, arg, source);
      }
      break;
    }

    case SyntaxKind.IfStatement:
    {
      for (const cs of node.cases) {
        if (cs.test !== null) {
          this.addReferencesToGraph(graph, cs.test, source);
        }
        for (const element of cs.elements) {
          this.addReferencesToGraph(graph, element, source);
        }
      }
      break;
    }

    case SyntaxKind.ExpressionStatement:
    {
      this.addReferencesToGraph(graph, node.expression, source);
      break;
    }

    case SyntaxKind.ReturnStatement:
    {
      if (node.expression !== null) {
        this.addReferencesToGraph(graph, node.expression, source);
      }
      break;
    }

    case SyntaxKind.LetDeclaration:
    {
      graph.addVertex(node);
      if (node.body !== null) {
        // Everything inside the body is referenced *from* this declaration.
        switch (node.body.kind) {
          case SyntaxKind.ExprBody:
          {
            this.addReferencesToGraph(graph, node.body.expression, node);
            break;
          }
          case SyntaxKind.BlockBody:
          {
            for (const element of node.body.elements) {
              this.addReferencesToGraph(graph, element, node);
            }
            break;
          }
        }
      }
      break;
    }

    // Type-level declarations contain no value references.
    case SyntaxKind.TypeDeclaration:
    case SyntaxKind.EnumDeclaration:
    case SyntaxKind.StructDeclaration:
      break;

    default:
      throw new Error(`Unexpected ${node.constructor.name}`);

  }

}
/**
 * Makes sure every let-declaration reachable from `node` has at least one
 * incoming edge in `graph`.
 *
 * A declaration for which `graph.getSourceVertices` yields nothing gets an
 * edge from the node owning its parent scope, labeled `false`. Declarations
 * nested in block bodies are processed recursively.
 *
 * @param graph The reference graph to complete.
 * @param node  The source file, declaration or statement to visit.
 * @throws Error when a node kind that is not handled here is encountered.
 */
private completeReferenceGraph(graph: ReferenceGraph, node: Syntax): void {
  switch (node.kind) {
    case SyntaxKind.SourceFile:
    {
      for (const element of node.elements) {
        this.completeReferenceGraph(graph, element);
      }
      break;
    }
    case SyntaxKind.LetDeclaration:
    {
      if (isEmpty(graph.getSourceVertices(node))) {
        const source = node.parent!.getScope().node;
        assert(source.kind === SyntaxKind.LetDeclaration || source.kind === SyntaxKind.SourceFile);
        graph.addEdge(source, node, false);
      }
      if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
        for (const element of node.body.elements) {
          this.completeReferenceGraph(graph, element);
        }
      }
      break;
    }
    // These nodes are not descended into by this pass.
    case SyntaxKind.IfStatement:
    case SyntaxKind.ReturnStatement:
    case SyntaxKind.ExpressionStatement:
    case SyntaxKind.TypeDeclaration:
    case SyntaxKind.EnumDeclaration:
    case SyntaxKind.StructDeclaration:
      break;
    default:
      // Interpolating the node itself would stringify the object rather than
      // name its kind; report the class name like the sibling passes do.
      throw new Error(`Unexpected ${node.constructor.name}`);
  }
}
2022-09-07 12:45:38 +02:00
/**
 * Creates the type environments for `node` and registers the type-level
 * declarations it contains.
 *
 * Source files and let-declarations receive a fresh `TypeEnv` chained to
 * `parentEnv`, and the elements of source files and block bodies are
 * initialized recursively. Enum, type and struct declarations are inferred
 * inside their own temporary context and their resulting scheme is added to
 * `parentEnv` under the declaration's name.
 *
 * @param node      The node to initialize.
 * @param parentEnv The environment enclosing `node`.
 * @throws Error when a node kind (or enum member kind) that is not handled
 *         here is encountered.
 */
private initialize(node: Syntax, parentEnv: TypeEnv): void {

  switch (node.kind) {

    case SyntaxKind.SourceFile:
    {
      const env = node.typeEnv = new TypeEnv(parentEnv);
      for (const element of node.elements) {
        this.initialize(element, env);
      }
      break;
    }

    case SyntaxKind.LetDeclaration:
    {
      const env = node.typeEnv = new TypeEnv(parentEnv);
      if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
        for (const element of node.body.elements) {
          this.initialize(element, env);
        }
      }
      break;
    }

    case SyntaxKind.IfStatement:
    case SyntaxKind.ExpressionStatement:
    case SyntaxKind.ReturnStatement:
      break;

    case SyntaxKind.EnumDeclaration:
    {
      const env = node.typeEnv = new TypeEnv(parentEnv);
      const constraints = new ConstraintSet();
      const typeVars = new TVSet();
      const context: InferContext = {
        typeVars,
        env,
        constraints,
        returnType: null,
      }
      this.pushContext(context);
      // One fresh type variable per declared type parameter, visible only
      // inside the declaration's own environment.
      const kindArgs = [];
      for (const varExpr of node.varExps) {
        const kindArg = this.createTypeVar();
        env.add(varExpr.text, new Forall([], [], kindArg));
        kindArgs.push(kindArg);
      }
      const type = new TVariant(node, [], [], node);
      if (node.members !== null) {
        for (const member of node.members) {
          switch (member.kind) {
            case SyntaxKind.EnumDeclarationTupleElement:
            {
              // A tuple-like member is registered in the parent environment
              // as a function from its element types to the applied type.
              const argTypes = member.elements.map(el => this.inferTypeExpression(el));
              const elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ]));
              parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType));
              break;
            }
            // TODO
            default:
              throw new Error(`Unexpected ${member.constructor.name}`);
          }
        }
      }
      this.popContext(context);
      parentEnv.add(node.name.text, new Forall(typeVars, constraints, type));
      break;
    }

    case SyntaxKind.TypeDeclaration:
    {
      const env = node.typeEnv = new TypeEnv(parentEnv);
      const constraints = new ConstraintSet();
      const typeVars = new TVSet();
      const context: InferContext = {
        constraints,
        typeVars,
        env,
        returnType: null,
      };
      this.pushContext(context);
      for (const varExpr of node.typeVars) {
        env.add(varExpr.text, new Forall([], [], this.createTypeVar()));
      }
      const type = this.inferTypeExpression(node.typeExpression);
      this.popContext(context);
      const scheme = new Forall(typeVars, constraints, type);
      parentEnv.add(node.name.text, scheme);
      node.scheme = scheme;
      break;
    }

    case SyntaxKind.StructDeclaration:
    {
      const env = node.typeEnv = new TypeEnv(parentEnv);
      const typeVars = new TVSet();
      const constraints = new ConstraintSet();
      const context: InferContext = {
        constraints,
        typeVars,
        env,
        returnType: null,
      };
      this.pushContext(context);
      const argTypes = [];
      for (const varExpr of node.typeVars) {
        const type = this.createTypeVar();
        env.add(varExpr.text, new Forall([], [], type));
        argTypes.push(type);
      }
      const fields = new Map<string, Type>();
      if (node.members !== null) {
        for (const member of node.members) {
          fields.set(member.name.text, this.inferTypeExpression(member.typeExpr));
        }
      }
      this.popContext(context);
      const type = new TRecord(node, argTypes, fields, node);
      const scheme = new Forall(typeVars, constraints, type);
      parentEnv.add(node.name.text, scheme);
      node.tvs = argTypes;
      node.scheme = scheme; //new Forall(typeVars, constraints, new TApp(type, argTypes));
      break;
    }

    default:
      // Report the class name; interpolating the node itself would only
      // stringify the object.
      throw new Error(`Unexpected ${node.constructor.name}`);

  }

}
/**
 * Type-checks a complete source file.
 *
 * The steps performed, in order:
 *  1. kind inference with the built-in type names registered as `KStar`;
 *  2. creation of the root type environment holding the built-in bindings;
 *  3. construction of the reference graph between declarations;
 *  4. initialization of the per-node type environments;
 *  5. a first pass over the components produced by `strongconnect` that
 *     binds the signature of every function-like let-declaration;
 *  6. a second pass that infers the bodies of those declarations;
 *  7. inference of the top-level elements and solving of the root
 *     constraint set into `this.solution`.
 *
 * @param node The source file to check.
 */
public check(node: SourceFile): void {

  const kenv = new KindEnv();
  kenv.setNamed('Int', new KStar());
  kenv.setNamed('String', new KStar());
  kenv.setNamed('Bool', new KStar());
  this.forwardDeclareKind(node, kenv);
  this.inferKind(node, kenv);

  const typeVars = new TVSet();
  const constraints = new ConstraintSet();
  const env = new TypeEnv();
  const context: InferContext = { typeVars, constraints, env, returnType: null };

  this.pushContext(context);

  const a = this.createTypeVar();
  const b = this.createTypeVar();

  // `$` has type (a -> b, a) -> b. Both `a` and `b` occur in the type, so both
  // must be quantified; leaving `b` out would share one result type between
  // all uses of the operator.
  env.add('$', new Forall([ a, b ], [], new TArrow([ new TArrow([ a ], b), a ], b)));
  env.add('String', new Forall([], [], this.stringType));
  env.add('Int', new Forall([], [], this.intType));
  env.add('True', new Forall([], [], this.boolType));
  env.add('False', new Forall([], [], this.boolType));
  env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
  env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
  env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
  env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
  env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
  env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));

  const graph = new LabeledDirectedHashGraph<NodeWithBindings, boolean>();
  this.addReferencesToGraph(graph, node, node);
  this.completeReferenceGraph(graph, node);

  this.initialize(node, env);

  // From here on, inference runs in the source file's own environment.
  this.pushContext({
    typeVars,
    constraints,
    env: node.typeEnv!,
    returnType: null
  });

  const sccs = [...strongconnect(graph)];

  // First pass: bind the signature of every function-like declaration.
  for (const nodes of sccs) {

    if (nodes.some(n => n.kind === SyntaxKind.SourceFile)) {
      assert(nodes.length === 1);
      continue;
    }

    // All declarations of one component share their type variables and
    // constraints.
    const typeVars = new TVSet();
    const constraints = new ConstraintSet();

    for (const node of nodes) {

      assert(node.kind === SyntaxKind.LetDeclaration);

      if (!isFunctionDeclarationLike(node)) {
        continue;
      }

      const env = node.typeEnv!;
      const context: InferContext = {
        typeVars,
        constraints,
        env,
        returnType: null,
      };
      node.context = context;

      this.contexts.push(context);

      const returnType = this.createTypeVar();
      context.returnType = returnType;

      // Parameters are bound inside the declaration's own environment.
      const paramTypes = [];
      for (const param of node.params) {
        const paramType = this.inferBindings(param.pattern, [], []);
        paramTypes.push(paramType);
      }

      const type = new TArrow(paramTypes, returnType);
      if (node.typeAssert !== null) {
        this.addConstraint(
          new CEqual(
            this.inferTypeExpression(node.typeAssert.typeExpression),
            type,
            node
          )
        );
      }
      node.type = type;

      this.contexts.pop();

      // FIXME get rid of all this useless stack manipulation
      // The declared name itself is bound in the environment of the
      // enclosing declaration, not in the declaration's own environment.
      const parentDecl = node.parent!.getScope().node;
      const bindCtx = {
        typeVars: context.typeVars,
        constraints: context.constraints,
        env: parentDecl.typeEnv!,
        returnType: null,
      };
      this.contexts.push(bindCtx)
      let ty2;
      if (node.pattern.kind === SyntaxKind.WrappedOperator) {
        ty2 = this.createTypeVar();
        this.addBinding(node.pattern.operator.text, new Forall([], [], ty2));
      } else {
        ty2 = this.inferBindings(node.pattern, typeVars, constraints);
      }
      this.addConstraint(new CEqual(ty2, type, node));
      this.contexts.pop();

    }

  }

  const visitElements = (elements: Syntax[]) => {
    for (const element of elements) {
      if (element.kind === SyntaxKind.LetDeclaration
          && isFunctionDeclarationLike(element)
          && graph.hasEdge(node, element, false)) {
        // A function-like declaration that only has the `false`-labeled edge
        // from the source file is instantiated once here; everything else is
        // inferred normally.
        assert(element.pattern.kind === SyntaxKind.BindPattern);
        const scheme = this.lookup(element.pattern.name.text);
        assert(scheme !== null);
        this.instantiate(scheme, null);
      } else {
        this.infer(element);
      }
    }
  }

  // Second pass: infer the bodies, component by component.
  for (const nodes of sccs) {

    if (nodes.some(n => n.kind === SyntaxKind.SourceFile)) {
      assert(nodes.length === 1);
      continue;
    }

    // Mark the whole component as active while its bodies are inferred.
    for (const node of nodes) {
      assert(node.kind === SyntaxKind.LetDeclaration);
      node.active = true;
    }

    for (const node of nodes) {

      assert(node.kind === SyntaxKind.LetDeclaration);

      if (!isFunctionDeclarationLike(node)) {
        continue;
      }

      const context = node.context!;
      const returnType = context.returnType!;
      this.contexts.push(context);

      if (node.body !== null) {
        switch (node.body.kind) {
          case SyntaxKind.ExprBody:
          {
            this.addConstraint(
              new CEqual(
                this.inferExpression(node.body.expression),
                returnType,
                node.body.expression
              )
            );
            break;
          }
          case SyntaxKind.BlockBody:
          {
            visitElements(node.body.elements);
            break;
          }
        }
      }

      this.contexts.pop();
    }

    for (const node of nodes) {
      assert(node.kind === SyntaxKind.LetDeclaration);
      node.active = false;
    }

  }

  visitElements(node.elements);

  // Undo the two context pushes made above, then solve what was collected.
  this.contexts.pop();
  this.popContext(context);

  this.solve(new CMany(constraints), this.solution);

}
/**
 * Solves `constraint`, recording the resulting type-variable assignments in
 * `solution`.
 *
 * Constraints are processed from a work queue: `Many` constraints are
 * flattened into the queue and `Equal` constraints are handled by unifying
 * both sides. Each failed equality counts as one error; solving stops as
 * soon as `MAX_TYPE_ERROR_COUNT` errors have been counted.
 *
 * @param constraint The (possibly compound) constraint to solve.
 * @param solution   The substitution that is read and extended while solving.
 */
private solve(constraint: Constraint, solution: TVSub): void {

  const queue = [ constraint ];

  let errorCount = 0;

  while (queue.length > 0) {

    const constraint = queue.shift()!;

    switch (constraint.kind) {

      case ConstraintKind.Many:
      {
        for (const element of constraint.elements) {
          queue.push(element);
        }
        break;
      }

      case ConstraintKind.Equal:
      {
        // constraint.dump();

        // Unifies two types, returning whether unification succeeded.
        const unify = (left: Type, right: Type): boolean => {

          // Follows the substitution until a non-variable or an unassigned
          // variable is reached.
          const find = (type: Type): Type => {
            while (type.kind === TypeKind.Var && solution.has(type)) {
              type = solution.get(type)!;
            }
            return type;
          }

          left = find(left);
          right = find(right);

          if (left.kind === TypeKind.Var) {
            if (right.hasTypeVar(left)) {
              // TODO occurs check diagnostic
              return false;
            }
            solution.set(left, right);
            TypeBase.join(left, right);
            return true;
          }

          if (right.kind === TypeKind.Var) {
            return unify(right, left);
          }

          if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) {
            if (left.paramTypes.length !== right.paramTypes.length) {
              this.diagnostics.add(new ArityMismatchDiagnostic(left, right));
              return false;
            }
            let success = true;
            const count = left.paramTypes.length;
            for (let i = 0; i < count; i++) {
              if (!unify(left.paramTypes[i], right.paramTypes[i])) {
                success = false;
              }
            }
            if (!unify(left.returnType, right.returnType)) {
              success = false;
            }
            if (success) {
              TypeBase.join(left, right);
            }
            return success;
          }

          // An arrow without parameters is unified through its return type.
          if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) {
            return unify(left.returnType, right);
          }

          if (right.kind === TypeKind.Arrow) {
            return unify(right, left);
          }

          if (left.kind === TypeKind.Con && right.kind === TypeKind.Con) {
            if (left.id === right.id) {
              assert(left.argTypes.length === right.argTypes.length);
              const count = left.argTypes.length;
              let success = true;
              for (let i = 0; i < count; i++) {
                if (!unify(left.argTypes[i], right.argTypes[i])) {
                  success = false;
                }
              }
              if (success) {
                TypeBase.join(left, right);
              }
              return success;
            }
            // Different constructors fall through to the generic failure below.
          }

          if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
            // Start out successful; only a failing field unification may turn
            // this into a failure.
            let success = true;
            // This works like an ordinary union-find algorithm where an additional
            // property 'fields' is carried over from the child nodes to the
            // ever-changing root node.
            const root = left.find();
            right.parent = root;
            if (root.fields === undefined) {
              root.fields = new Map([ [ root.name, root.type ] ]);
            }
            if (right.fields === undefined) {
              right.fields = new Map([ [ right.name, right.type ] ]);
            }
            for (const [fieldName, fieldType] of right.fields) {
              if (root.fields.has(fieldName)) {
                if (!unify(root.fields.get(fieldName)!, fieldType)) {
                  success = false;
                }
              } else {
                root.fields.set(fieldName, fieldType);
              }
            }
            delete right.fields;
            if (success) {
              TypeBase.join(left, right);
            }
            return success;
          }

          if (left.kind === TypeKind.Variant && right.kind === TypeKind.Variant) {
            if (left.decl !== right.decl) {
              this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
              return false;
            }
            return true;
          }

          if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
            return unify(left.left, right.left)
                && unify(left.right, right.right);
          }

          if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
            if (left.decl !== right.decl) {
              this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
              return false;
            }
            let success = true;
            // Fields of `right` that were not matched by a field of `left`.
            const remaining = new Set(right.fields.keys());
            for (const [fieldName, fieldType] of left.fields) {
              if (right.fields.has(fieldName)) {
                if (!unify(fieldType, right.fields.get(fieldName)!)) {
                  success = false;
                }
                remaining.delete(fieldName);
              } else {
                this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node));
                success = false;
              }
            }
            for (const fieldName of remaining) {
              this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node));
              // A reported field error means the records did not unify.
              success = false;
            }
            if (success) {
              TypeBase.join(left, right);
            }
            return success;
          }

          if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
            let success = true;
            if (right.fields === undefined) {
              right.fields = new Map([ [ right.name, right.type ] ]);
            }
            for (const [fieldName, fieldType] of right.fields) {
              if (left.fields.has(fieldName)) {
                if (!unify(fieldType, left.fields.get(fieldName)!)) {
                  success = false;
                }
              } else {
                this.diagnostics.add(new FieldMissingDiagnostic(left, fieldName, constraint.node));
                // A reported field error means the types did not unify.
                success = false;
              }
            }
            if (success) {
              TypeBase.join(left, right);
            }
            return success;
          }

          if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Record) {
            return unify(right, left);
          }

          // No rule applied: report the mismatch with the current solution
          // substituted into both sides.
          this.diagnostics.add(
            new UnificationFailedDiagnostic(
              left.substitute(solution),
              right.substitute(solution),
              [...constraint.getNodes()],
            )
          );
          return false;
        }

        if (!unify(constraint.left, constraint.right)) {
          errorCount++;
          if (errorCount === MAX_TYPE_ERROR_COUNT) {
            return;
          }
        }

        break;
      }

    }

  }

}
}