From 4cc2b23109b55678b6999872d3d16cbbf3c309ea Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Wed, 14 Sep 2022 16:46:30 +0200 Subject: [PATCH] Add experimental support for kind inference --- src/checker.ts | 373 ++++++++++++++++++++++++++++++++++++++++++++- src/cst.ts | 35 ++++- src/diagnostics.ts | 32 ++-- src/parser.ts | 40 ++++- 4 files changed, 454 insertions(+), 26 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index 0e96415c1..86f739d26 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -4,6 +4,7 @@ import { Pattern, Scope, SourceFile, + StructDeclaration, Symkind, Syntax, SyntaxKind, @@ -459,6 +460,95 @@ function isKindedType(type: Type): type is KindedType { || type.kind === TypeKind.Record; } +export const enum KindType { + Star, + Arrow, + Var, +} + +class KVSub { + + private mapping = new Map(); + + public set(kv: KVar, kind: Kind): void { + this.mapping.set(kv.id, kind); + } + + public get(kv: KVar): Kind | undefined { + return this.mapping.get(kv.id); + } + + public has(kv: KVar): boolean { + return this.mapping.has(kv.id); + } + + public values(): Iterable { + return this.mapping.values(); + } + +} + +abstract class KindBase { + + public abstract readonly type: KindType; + + public abstract substitute(sub: KVSub): Kind; + +} + +class KVar extends KindBase { + + public readonly type = KindType.Var; + + public constructor( + public id: number, + ) { + super(); + } + + public substitute(sub: KVSub): Kind { + const other = sub.get(this); + return other === undefined + ? this : other.substitute(sub); + } + +} + +class KStar extends KindBase { + + public readonly type = KindType.Star; + + public substitute(_sub: KVSub): Kind { + return this; + } + +} + +class KArrow extends KindBase { + + public readonly type = KindType.Arrow; + + public constructor( + public left: Kind, + public right: Kind, + ) { + super(); + } + + public substitute(sub: KVSub): Kind { + return new KArrow( + this.left.substitute(sub), + this.right.substitute(sub), + ); + } + +} + +export type Kind + = KStar + | KArrow + | KVar + class TVSet { private mapping = new Map(); @@ -663,6 +753,43 @@ export class TypeEnv { } +class KindEnv { + + private mapping1 = new Map(); + private mapping2 = new Map(); + + public constructor(public parent: KindEnv | null = null) { + + } + + public setNamed(name: string, kind: Kind): void { + assert(!this.mapping1.has(name)); + this.mapping1.set(name, kind); + } + + public setVar(tv: TVar, kind: Kind): void { + assert(!this.mapping2.has(tv.id)); + this.mapping2.set(tv.id, kind); + } + + public lookupNamed(name: string): Kind | null { + let curr: KindEnv | null = this; + do { + const kind = curr.mapping1.get(name); + if (kind !== undefined) { + return kind; + } + curr = curr.parent; + } while (curr !== null); + return null; + } + + public lookupVar(tv: TVar): Kind | null { + return this.mapping2.get(tv.id) ?? null; + } + +} + export interface InferContext { typeVars: TVSet; env: TypeEnv; @@ -678,11 +805,9 @@ function isFunctionDeclarationLike(node: LetDeclaration): boolean { export class Checker { private nextTypeVarId = 0; + private nextKindVarId = 0; private nextConTypeId = 0; - //private graph?: Graph; - //private currentCycle?: Map; - private stringType = new TCon(this.nextConTypeId++, [], 'String'); private intType = new TCon(this.nextConTypeId++, [], 'Int'); private boolType = new TCon(this.nextConTypeId++, [], 'Bool'); @@ -690,6 +815,7 @@ export class Checker { private contexts: InferContext[] = []; private solution = new TVSub(); + private kindSolution = new KVSub(); public constructor( private diagnostics: Diagnostics @@ -763,7 +889,221 @@ export class Checker { context.env.add(name, scheme); } - public infer(node: Syntax): void { + private inferKindFromTypeExpression(node: TypeExpression, env: KindEnv): Kind { + switch (node.kind) { + case SyntaxKind.VarTypeExpression: + case SyntaxKind.ReferenceTypeExpression: + { + const kind = env.lookupNamed(node.name.text); + if (kind === null) { + this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); + // Create a filler kind variable that still will be able to catch other errors. + return this.createKindVar(); + } + return kind; + } + case SyntaxKind.AppTypeExpression: + { + let operator = this.inferKindFromTypeExpression(node.operator, env); + const args = node.args.map(arg => this.inferKindFromTypeExpression(arg, env)); + let result = operator; + for (const arg of args) { + result = this.applyKind(result, arg, node); + } + return result; + } + case SyntaxKind.NestedTypeExpression: + { + return this.inferKindFromTypeExpression(node.typeExpr, env); + } + default: + throw new Error(`Unexpected ${node}`); + } + } + + private createKindVar(): KVar { + return new KVar(this.nextKindVarId++); + } + + private applyKind(operator: Kind, arg: Kind, node: Syntax): Kind { + switch (operator.type) { + case KindType.Var: + { + const a1 = this.createKindVar(); + const a2 = this.createKindVar(); + const arrow = new KArrow(a1, a2); + this.unifyKind(arrow, operator, node); + this.unifyKind(a1, arg, node); + return a2; + } + case KindType.Arrow: + { + // Unify the argument to the operator's argument kind and return + // whatever the operator returns. + this.unifyKind(operator.left, arg, node); + return operator.right; + } + case KindType.Star: + { + this.diagnostics.add( + new KindMismatchDiagnostic( + operator, + new KArrow( + this.createKindVar(), + this.createKindVar() + ), + node + ) + ); + // Create a filler kind variable that still will be able to catch other errors. + return this.createKindVar(); + } + } + } + + private forwardDeclareKind(node: Syntax, env: KindEnv): void { + + switch (node.kind) { + + case SyntaxKind.SourceFile: + { + for (const element of node.elements) { + this.forwardDeclareKind(element, env); + } + break; + } + + case SyntaxKind.StructDeclaration: + case SyntaxKind.EnumDeclaration: + { + env.setNamed(node.name.text, this.createKindVar()); + if (node.members !== null) { + for (const member of node.members) { + env.setNamed(member.name.text, this.createKindVar()); + } + } + break; + } + + } + + } + + private inferKind(node: Syntax, env: KindEnv): void { + switch (node.kind) { + case SyntaxKind.SourceFile: + { + for (const element of node.elements) { + this.inferKind(element, env); + } + break; + } + case SyntaxKind.StructDeclaration: + { + // TODO + break; + } + case SyntaxKind.EnumDeclaration: + { + const declKind = env.lookupNamed(node.name.text)!; + const innerEnv = new KindEnv(env); + let kind: Kind = new KStar(); + // FIXME should I go from right to left or left to right? + for (let i = node.varExps.length-1; i >= 0; i--) { + const varExpr = node.varExps[i]; + const paramKind = this.createKindVar(); + innerEnv.setNamed(varExpr.text, paramKind); + kind = new KArrow(paramKind, kind); + } + this.unifyKind(declKind, kind, node); + if (node.members !== null) { + for (const member of node.members) { + switch (member.kind) { + case SyntaxKind.EnumDeclarationTupleElement: + { + for (const element of member.elements) { + this.unifyKind(this.inferKindFromTypeExpression(element, innerEnv), new KStar(), element); + } + break; + } + // TODO + } + } + } + break; + } + } + } + + private unifyKind(a: Kind, b: Kind, node: Syntax): boolean { + + const find = (kind: Kind): Kind => { + let curr = kind; + while (curr.type === KindType.Var && this.kindSolution.has(curr)) { + curr = this.kindSolution.get(curr)!; + } + // if (kind.type === KindType.Var && ) { + // this.kindSolution.set(kind.id, curr); + // } + return curr; + } + + const solve = (kind: Kind) => kind.substitute(this.kindSolution); + + a = find(a); + b = find(b); + + if (a.type === KindType.Var) { + this.kindSolution.set(a, b); + return true; + } + + if (b.type === KindType.Var) { + return this.unifyKind(b, a, node); + } + + if (a.type === KindType.Star && b.type === KindType.Star) { + return true; + } + + if (a.type === KindType.Arrow && b.type === KindType.Arrow) { + return this.unifyKind(a.left, b.left, node) + || this.unifyKind(a.right, b.right, node); + // let success = true; + // const leftStack = []; + // const rightStack = []; + // let leftCurr: Kind = a; + // let rightCurr: Kind = b; + // for (;;) { + // while (leftCurr.type === KindType.Arrow) { + // leftStack.push(leftCurr); + // leftCurr = find(leftCurr.left); + // } + // while (rightCurr.type === KindType.Arrow) { + // rightStack.push(rightCurr); + // rightCurr = find(rightCurr.left); + // } + // if (!this.unifyKind(leftCurr, rightCurr, node)) { + // success = false; + // } + // if (leftStack.length === 0 || rightStack.length === 0) { + // if (leftStack.length > 0 || rightStack.length > 0) { + // this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node)); + // success = false; + // } + // break; + // } + // rightCurr = find(rightStack.pop()!.right); + // leftCurr = find(leftStack.pop()!.right); + // } + // return success; + } + + this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node)); + return false; + } + + private infer(node: Syntax): void { switch (node.kind) { @@ -1010,7 +1350,20 @@ export class Checker { throw new Error(`Unexpected ${member}`); } } - const type = new TRecord(decl, argTypes, fields, node); + let type = new TRecord(decl, argTypes, fields, node); + if (decl.kind === SyntaxKind.EnumDeclarationStructElement) { + const elementTypes = []; + for (const element of decl.parent!.elements) { + let elementType; + if (element === decl) { + elementType = type; + } else { + elementType = this.createTypeVar(); + } + elementTypes.push(elementType); + } + type = new TVariant(typeVars, elementTypes); + } this.addConstraint( new CEqual( declType, @@ -1467,6 +1820,13 @@ export class Checker { public check(node: SourceFile): void { + const kenv = new KindEnv(); + kenv.setNamed('Int', new KStar()); + kenv.setNamed('String', new KStar()); + kenv.setNamed('Bool', new KStar()); + this.forwardDeclareKind(node, kenv); + this.inferKind(node, kenv); + const typeVars = new TVSet(); const constraints = new ConstraintSet(); const env = new TypeEnv(); @@ -1475,7 +1835,10 @@ export class Checker { this.pushContext(context); const a = this.createTypeVar(); + const b = this.createTypeVar(); + const f = this.createTypeVar(); + env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b))); env.add('String', new Forall([], [], this.stringType)); env.add('Int', new Forall([], [], this.intType)); env.add('True', new Forall([], [], this.boolType)); diff --git a/src/cst.ts b/src/cst.ts index 0163b3ab6..6b14f5a1a 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -109,6 +109,7 @@ export const enum SyntaxKind { ArrowTypeExpression, VarTypeExpression, AppTypeExpression, + NestedTypeExpression, // Patterns BindPattern, @@ -257,9 +258,17 @@ export class Scope { break; } case SyntaxKind.EnumDeclaration: + { + this.add(node.name.text, node, Symkind.Type); + if (node.members !== null) { + for (const member of node.members) { + this.add(member.name.text, member, Symkind.Constructor); + } + } + } case SyntaxKind.StructDeclaration: { - this.add(node.name.text, node, Symkind.Constructor); + this.add(node.name.text, node, Symkind.Constructor | Symkind.Type); break; } case SyntaxKind.LetDeclaration: @@ -1032,11 +1041,34 @@ export class VarTypeExpression extends SyntaxBase { } +export class NestedTypeExpression extends SyntaxBase { + + public readonly kind = SyntaxKind.NestedTypeExpression; + + public constructor( + public lparen: LParen, + public typeExpr: TypeExpression, + public rparen: RParen, + ) { + super(); + } + + public getFirstToken(): Token { + return this.lparen; + } + + public getLastToken(): Token { + return this.rparen; + } + +} + export type TypeExpression = ReferenceTypeExpression | ArrowTypeExpression | VarTypeExpression | AppTypeExpression + | NestedTypeExpression export class BindPattern extends SyntaxBase { @@ -1730,6 +1762,7 @@ export class EnumDeclaration extends SyntaxBase { public pubKeyword: PubKeyword | null, public enumKeyword: EnumKeyword, public name: IdentifierAlt, + public varExps: Identifier[], public members: EnumDeclarationElement[] | null, ) { super(); diff --git a/src/diagnostics.ts b/src/diagnostics.ts index 71856d871..e745ca5c3 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -1,6 +1,6 @@ import { describe } from "yargs"; -import { TypeKind, type Type, type TArrow, TRecord } from "./checker"; +import { TypeKind, type Type, type TArrow, TRecord, Kind, KindType } from "./checker"; import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst"; import { countDigits, IndentWriter } from "./util"; @@ -70,6 +70,10 @@ const DESCRIPTIONS: Partial> = { [SyntaxKind.RBrace]: "'}'", [SyntaxKind.LBracket]: "'['", [SyntaxKind.RBracket]: "']'", + [SyntaxKind.StructKeyword]: "'struct'", + [SyntaxKind.EnumKeyword]: "'enum'", + [SyntaxKind.MatchKeyword]: "'match'", + [SyntaxKind.TypeKeyword]: "'type'", [SyntaxKind.IdentifierAlt]: 'an identifier starting with an uppercase letter', [SyntaxKind.ConstantExpression]: 'a constant expression', [SyntaxKind.ReferenceExpression]: 'a reference expression', @@ -196,6 +200,7 @@ export function describeType(type: Type): string { } return out; } + case TypeKind.Variant: case TypeKind.Record: { return type.decl.name.text; @@ -220,6 +225,17 @@ export function describeType(type: Type): string { } } +function describeKind(kind: Kind): string { + switch (kind.type) { + case KindType.Var: + return `a${kind.id}`; + case KindType.Arrow: + return describeKind(kind.left) + ' -> ' + describeKind(kind.right); + case KindType.Star: + return '*'; + } +} + function getFirstNodeInTypeChain(type: Type): Syntax | null { let curr = type.next; while (curr !== type && (curr.kind === TypeKind.Var || curr.node === null)) { @@ -344,8 +360,8 @@ export class KindMismatchDiagnostic { public readonly level = Level.Error; public constructor( - public leftSize: number, - public rightSize: number, + public left: Kind, + public right: Kind, public node: Syntax | null, ) { @@ -353,15 +369,7 @@ export class KindMismatchDiagnostic { public format(out: IndentWriter): void { out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET); - out.write(`kind `); - for (let i = 0; i < this.leftSize-1; i++) { - out.write(`* -> `); - } - out.write(`* does not match with `); - for (let i = 0; i < this.rightSize-1; i++) { - out.write(`* -> `); - } - out.write(`*\n\n`); + out.write(`kind ${describeKind(this.left)} does not match with ${describeKind(this.right)}\n\n`); if (this.node !== null) { out.write(printNode(this.node) + '\n'); } diff --git a/src/parser.ts b/src/parser.ts index 27b269a43..8c725b8f4 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -54,6 +54,8 @@ import { VarTypeExpression, TypeDeclaration, AppTypeExpression, + NestedPattern, + NestedTypeExpression, } from "./cst" import { Stream } from "./util"; @@ -176,6 +178,13 @@ export class Parser { this.getToken(); return new VarTypeExpression(t0); } + case SyntaxKind.LParen: + { + this.getToken(); + const typeExpr = this.parseTypeExpression(); + const t2 = this.expectToken(SyntaxKind.RParen); + return new NestedTypeExpression(t0, typeExpr, t2); + } case SyntaxKind.IdentifierAlt: return this.parseReferenceTypeExpression(); default: @@ -477,11 +486,15 @@ export class Parser { this.raiseParseError(t0, [ SyntaxKind.EnumKeyword ]); } const name = this.expectToken(SyntaxKind.IdentifierAlt); - const t1 = this.peekToken(); + let t1 = this.getToken(); + const varExps = []; + while (t1.kind === SyntaxKind.Identifier) { + varExps.push(t1); + t1 = this.getToken(); + } let members = null; if (t1.kind === SyntaxKind.BlockStart) { members = []; - this.getToken(); for (;;) { const t2 = this.peekToken(); if (t2.kind === SyntaxKind.BlockEnd) { @@ -498,6 +511,7 @@ export class Parser { const name = this.expectToken(SyntaxKind.Identifier); const colon = this.expectToken(SyntaxKind.Colon); const typeExpr = this.parseTypeExpression(); + this.expectToken(SyntaxKind.LineFoldEnd); members.push(new StructDeclarationField(name, colon, typeExpr)); const t4 = this.peekToken(); if (t4.kind === SyntaxKind.BlockEnd) { @@ -521,9 +535,12 @@ export class Parser { members.push(member); this.expectToken(SyntaxKind.LineFoldEnd); } + t1 = this.getToken(); } - this.expectToken(SyntaxKind.LineFoldEnd); - return new EnumDeclaration(pubKeyword, t0, name, members); + if (t1.kind !== SyntaxKind.LineFoldEnd) { + this.raiseParseError(t1, [ SyntaxKind.Identifier, SyntaxKind.BlockStart, SyntaxKind.LineFoldEnd ]); + } + return new EnumDeclaration(pubKeyword, t0, name, varExps, members); } public parseStructDeclaration(): StructDeclaration { @@ -651,19 +668,26 @@ export class Parser { switch (t0.kind) { case SyntaxKind.LParen: { - const t1 = this.peekToken(); + const t1 = this.peekToken(2); if (t1.kind === SyntaxKind.IdentifierAlt) { this.getToken(); - return this.parsePatternStartingWithConstructor(); + const pattern = this.parsePatternStartingWithConstructor(); + const t3 = this.expectToken(SyntaxKind.RParen); + return new NestedPattern(t0, pattern, t3); } else { return this.parseTuplePattern(); } } case SyntaxKind.IdentifierAlt: - return this.parsePatternStartingWithConstructor(); + { + this.getToken(); + return new NamedTuplePattern(t0, []); + } case SyntaxKind.Identifier: + { this.getToken(); return new BindPattern(t0); + } default: this.raiseParseError(t0, [ SyntaxKind.Identifier ]); } @@ -866,7 +890,7 @@ export class Parser { case SyntaxKind.StructKeyword: return this.parseStructDeclaration(); case SyntaxKind.EnumKeyword: - return this.parseStructDeclaration(); + return this.parseEnumDeclaration(); case SyntaxKind.TypeKeyword: return this.parseTypeDeclaration(); case SyntaxKind.IfKeyword: