Add experimental support for kind inference

This commit is contained in:
Sam Vervaeck 2022-09-14 16:46:30 +02:00
parent 0eada4068c
commit 4cc2b23109
4 changed files with 454 additions and 26 deletions

View file

@ -4,6 +4,7 @@ import {
Pattern, Pattern,
Scope, Scope,
SourceFile, SourceFile,
StructDeclaration,
Symkind, Symkind,
Syntax, Syntax,
SyntaxKind, SyntaxKind,
@ -459,6 +460,95 @@ function isKindedType(type: Type): type is KindedType {
|| type.kind === TypeKind.Record; || type.kind === TypeKind.Record;
} }
export const enum KindType {
Star,
Arrow,
Var,
}
class KVSub {
private mapping = new Map<number, Kind>();
public set(kv: KVar, kind: Kind): void {
this.mapping.set(kv.id, kind);
}
public get(kv: KVar): Kind | undefined {
return this.mapping.get(kv.id);
}
public has(kv: KVar): boolean {
return this.mapping.has(kv.id);
}
public values(): Iterable<Kind> {
return this.mapping.values();
}
}
abstract class KindBase {
public abstract readonly type: KindType;
public abstract substitute(sub: KVSub): Kind;
}
class KVar extends KindBase {
public readonly type = KindType.Var;
public constructor(
public id: number,
) {
super();
}
public substitute(sub: KVSub): Kind {
const other = sub.get(this);
return other === undefined
? this : other.substitute(sub);
}
}
class KStar extends KindBase {
public readonly type = KindType.Star;
public substitute(_sub: KVSub): Kind {
return this;
}
}
class KArrow extends KindBase {
public readonly type = KindType.Arrow;
public constructor(
public left: Kind,
public right: Kind,
) {
super();
}
public substitute(sub: KVSub): Kind {
return new KArrow(
this.left.substitute(sub),
this.right.substitute(sub),
);
}
}
export type Kind
= KStar
| KArrow
| KVar
class TVSet { class TVSet {
private mapping = new Map<number, TVar>(); private mapping = new Map<number, TVar>();
@ -663,6 +753,43 @@ export class TypeEnv {
} }
class KindEnv {
private mapping1 = new Map<string, Kind>();
private mapping2 = new Map<number, Kind>();
public constructor(public parent: KindEnv | null = null) {
}
public setNamed(name: string, kind: Kind): void {
assert(!this.mapping1.has(name));
this.mapping1.set(name, kind);
}
public setVar(tv: TVar, kind: Kind): void {
assert(!this.mapping2.has(tv.id));
this.mapping2.set(tv.id, kind);
}
public lookupNamed(name: string): Kind | null {
let curr: KindEnv | null = this;
do {
const kind = curr.mapping1.get(name);
if (kind !== undefined) {
return kind;
}
curr = curr.parent;
} while (curr !== null);
return null;
}
public lookupVar(tv: TVar): Kind | null {
return this.mapping2.get(tv.id) ?? null;
}
}
export interface InferContext { export interface InferContext {
typeVars: TVSet; typeVars: TVSet;
env: TypeEnv; env: TypeEnv;
@ -678,11 +805,9 @@ function isFunctionDeclarationLike(node: LetDeclaration): boolean {
export class Checker { export class Checker {
private nextTypeVarId = 0; private nextTypeVarId = 0;
private nextKindVarId = 0;
private nextConTypeId = 0; private nextConTypeId = 0;
//private graph?: Graph<Syntax>;
//private currentCycle?: Map<Syntax, Type>;
private stringType = new TCon(this.nextConTypeId++, [], 'String'); private stringType = new TCon(this.nextConTypeId++, [], 'String');
private intType = new TCon(this.nextConTypeId++, [], 'Int'); private intType = new TCon(this.nextConTypeId++, [], 'Int');
private boolType = new TCon(this.nextConTypeId++, [], 'Bool'); private boolType = new TCon(this.nextConTypeId++, [], 'Bool');
@ -690,6 +815,7 @@ export class Checker {
private contexts: InferContext[] = []; private contexts: InferContext[] = [];
private solution = new TVSub(); private solution = new TVSub();
private kindSolution = new KVSub();
public constructor( public constructor(
private diagnostics: Diagnostics private diagnostics: Diagnostics
@ -763,7 +889,221 @@ export class Checker {
context.env.add(name, scheme); context.env.add(name, scheme);
} }
public infer(node: Syntax): void { private inferKindFromTypeExpression(node: TypeExpression, env: KindEnv): Kind {
switch (node.kind) {
case SyntaxKind.VarTypeExpression:
case SyntaxKind.ReferenceTypeExpression:
{
const kind = env.lookupNamed(node.name.text);
if (kind === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
// Create a filler kind variable that still will be able to catch other errors.
return this.createKindVar();
}
return kind;
}
case SyntaxKind.AppTypeExpression:
{
let operator = this.inferKindFromTypeExpression(node.operator, env);
const args = node.args.map(arg => this.inferKindFromTypeExpression(arg, env));
let result = operator;
for (const arg of args) {
result = this.applyKind(result, arg, node);
}
return result;
}
case SyntaxKind.NestedTypeExpression:
{
return this.inferKindFromTypeExpression(node.typeExpr, env);
}
default:
throw new Error(`Unexpected ${node}`);
}
}
private createKindVar(): KVar {
return new KVar(this.nextKindVarId++);
}
private applyKind(operator: Kind, arg: Kind, node: Syntax): Kind {
switch (operator.type) {
case KindType.Var:
{
const a1 = this.createKindVar();
const a2 = this.createKindVar();
const arrow = new KArrow(a1, a2);
this.unifyKind(arrow, operator, node);
this.unifyKind(a1, arg, node);
return a2;
}
case KindType.Arrow:
{
// Unify the argument to the operator's argument kind and return
// whatever the operator returns.
this.unifyKind(operator.left, arg, node);
return operator.right;
}
case KindType.Star:
{
this.diagnostics.add(
new KindMismatchDiagnostic(
operator,
new KArrow(
this.createKindVar(),
this.createKindVar()
),
node
)
);
// Create a filler kind variable that still will be able to catch other errors.
return this.createKindVar();
}
}
}
private forwardDeclareKind(node: Syntax, env: KindEnv): void {
switch (node.kind) {
case SyntaxKind.SourceFile:
{
for (const element of node.elements) {
this.forwardDeclareKind(element, env);
}
break;
}
case SyntaxKind.StructDeclaration:
case SyntaxKind.EnumDeclaration:
{
env.setNamed(node.name.text, this.createKindVar());
if (node.members !== null) {
for (const member of node.members) {
env.setNamed(member.name.text, this.createKindVar());
}
}
break;
}
}
}
private inferKind(node: Syntax, env: KindEnv): void {
switch (node.kind) {
case SyntaxKind.SourceFile:
{
for (const element of node.elements) {
this.inferKind(element, env);
}
break;
}
case SyntaxKind.StructDeclaration:
{
// TODO
break;
}
case SyntaxKind.EnumDeclaration:
{
const declKind = env.lookupNamed(node.name.text)!;
const innerEnv = new KindEnv(env);
let kind: Kind = new KStar();
// FIXME should I go from right to left or left to right?
for (let i = node.varExps.length-1; i >= 0; i--) {
const varExpr = node.varExps[i];
const paramKind = this.createKindVar();
innerEnv.setNamed(varExpr.text, paramKind);
kind = new KArrow(paramKind, kind);
}
this.unifyKind(declKind, kind, node);
if (node.members !== null) {
for (const member of node.members) {
switch (member.kind) {
case SyntaxKind.EnumDeclarationTupleElement:
{
for (const element of member.elements) {
this.unifyKind(this.inferKindFromTypeExpression(element, innerEnv), new KStar(), element);
}
break;
}
// TODO
}
}
}
break;
}
}
}
private unifyKind(a: Kind, b: Kind, node: Syntax): boolean {
const find = (kind: Kind): Kind => {
let curr = kind;
while (curr.type === KindType.Var && this.kindSolution.has(curr)) {
curr = this.kindSolution.get(curr)!;
}
// if (kind.type === KindType.Var && ) {
// this.kindSolution.set(kind.id, curr);
// }
return curr;
}
const solve = (kind: Kind) => kind.substitute(this.kindSolution);
a = find(a);
b = find(b);
if (a.type === KindType.Var) {
this.kindSolution.set(a, b);
return true;
}
if (b.type === KindType.Var) {
return this.unifyKind(b, a, node);
}
if (a.type === KindType.Star && b.type === KindType.Star) {
return true;
}
if (a.type === KindType.Arrow && b.type === KindType.Arrow) {
return this.unifyKind(a.left, b.left, node)
|| this.unifyKind(a.right, b.right, node);
// let success = true;
// const leftStack = [];
// const rightStack = [];
// let leftCurr: Kind = a;
// let rightCurr: Kind = b;
// for (;;) {
// while (leftCurr.type === KindType.Arrow) {
// leftStack.push(leftCurr);
// leftCurr = find(leftCurr.left);
// }
// while (rightCurr.type === KindType.Arrow) {
// rightStack.push(rightCurr);
// rightCurr = find(rightCurr.left);
// }
// if (!this.unifyKind(leftCurr, rightCurr, node)) {
// success = false;
// }
// if (leftStack.length === 0 || rightStack.length === 0) {
// if (leftStack.length > 0 || rightStack.length > 0) {
// this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
// success = false;
// }
// break;
// }
// rightCurr = find(rightStack.pop()!.right);
// leftCurr = find(leftStack.pop()!.right);
// }
// return success;
}
this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
return false;
}
private infer(node: Syntax): void {
switch (node.kind) { switch (node.kind) {
@ -1010,7 +1350,20 @@ export class Checker {
throw new Error(`Unexpected ${member}`); throw new Error(`Unexpected ${member}`);
} }
} }
const type = new TRecord(decl, argTypes, fields, node); let type = new TRecord(decl, argTypes, fields, node);
if (decl.kind === SyntaxKind.EnumDeclarationStructElement) {
const elementTypes = [];
for (const element of decl.parent!.elements) {
let elementType;
if (element === decl) {
elementType = type;
} else {
elementType = this.createTypeVar();
}
elementTypes.push(elementType);
}
type = new TVariant(typeVars, elementTypes);
}
this.addConstraint( this.addConstraint(
new CEqual( new CEqual(
declType, declType,
@ -1467,6 +1820,13 @@ export class Checker {
public check(node: SourceFile): void { public check(node: SourceFile): void {
const kenv = new KindEnv();
kenv.setNamed('Int', new KStar());
kenv.setNamed('String', new KStar());
kenv.setNamed('Bool', new KStar());
this.forwardDeclareKind(node, kenv);
this.inferKind(node, kenv);
const typeVars = new TVSet(); const typeVars = new TVSet();
const constraints = new ConstraintSet(); const constraints = new ConstraintSet();
const env = new TypeEnv(); const env = new TypeEnv();
@ -1475,7 +1835,10 @@ export class Checker {
this.pushContext(context); this.pushContext(context);
const a = this.createTypeVar(); const a = this.createTypeVar();
const b = this.createTypeVar();
const f = this.createTypeVar();
env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)));
env.add('String', new Forall([], [], this.stringType)); env.add('String', new Forall([], [], this.stringType));
env.add('Int', new Forall([], [], this.intType)); env.add('Int', new Forall([], [], this.intType));
env.add('True', new Forall([], [], this.boolType)); env.add('True', new Forall([], [], this.boolType));

View file

@ -109,6 +109,7 @@ export const enum SyntaxKind {
ArrowTypeExpression, ArrowTypeExpression,
VarTypeExpression, VarTypeExpression,
AppTypeExpression, AppTypeExpression,
NestedTypeExpression,
// Patterns // Patterns
BindPattern, BindPattern,
@ -257,9 +258,17 @@ export class Scope {
break; break;
} }
case SyntaxKind.EnumDeclaration: case SyntaxKind.EnumDeclaration:
{
this.add(node.name.text, node, Symkind.Type);
if (node.members !== null) {
for (const member of node.members) {
this.add(member.name.text, member, Symkind.Constructor);
}
}
}
case SyntaxKind.StructDeclaration: case SyntaxKind.StructDeclaration:
{ {
this.add(node.name.text, node, Symkind.Constructor); this.add(node.name.text, node, Symkind.Constructor | Symkind.Type);
break; break;
} }
case SyntaxKind.LetDeclaration: case SyntaxKind.LetDeclaration:
@ -1032,11 +1041,34 @@ export class VarTypeExpression extends SyntaxBase {
} }
export class NestedTypeExpression extends SyntaxBase {
public readonly kind = SyntaxKind.NestedTypeExpression;
public constructor(
public lparen: LParen,
public typeExpr: TypeExpression,
public rparen: RParen,
) {
super();
}
public getFirstToken(): Token {
return this.lparen;
}
public getLastToken(): Token {
return this.rparen;
}
}
export type TypeExpression export type TypeExpression
= ReferenceTypeExpression = ReferenceTypeExpression
| ArrowTypeExpression | ArrowTypeExpression
| VarTypeExpression | VarTypeExpression
| AppTypeExpression | AppTypeExpression
| NestedTypeExpression
export class BindPattern extends SyntaxBase { export class BindPattern extends SyntaxBase {
@ -1730,6 +1762,7 @@ export class EnumDeclaration extends SyntaxBase {
public pubKeyword: PubKeyword | null, public pubKeyword: PubKeyword | null,
public enumKeyword: EnumKeyword, public enumKeyword: EnumKeyword,
public name: IdentifierAlt, public name: IdentifierAlt,
public varExps: Identifier[],
public members: EnumDeclarationElement[] | null, public members: EnumDeclarationElement[] | null,
) { ) {
super(); super();

View file

@ -1,6 +1,6 @@
import { describe } from "yargs"; import { describe } from "yargs";
import { TypeKind, type Type, type TArrow, TRecord } from "./checker"; import { TypeKind, type Type, type TArrow, TRecord, Kind, KindType } from "./checker";
import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst"; import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst";
import { countDigits, IndentWriter } from "./util"; import { countDigits, IndentWriter } from "./util";
@ -70,6 +70,10 @@ const DESCRIPTIONS: Partial<Record<SyntaxKind, string>> = {
[SyntaxKind.RBrace]: "'}'", [SyntaxKind.RBrace]: "'}'",
[SyntaxKind.LBracket]: "'['", [SyntaxKind.LBracket]: "'['",
[SyntaxKind.RBracket]: "']'", [SyntaxKind.RBracket]: "']'",
[SyntaxKind.StructKeyword]: "'struct'",
[SyntaxKind.EnumKeyword]: "'enum'",
[SyntaxKind.MatchKeyword]: "'match'",
[SyntaxKind.TypeKeyword]: "'type'",
[SyntaxKind.IdentifierAlt]: 'an identifier starting with an uppercase letter', [SyntaxKind.IdentifierAlt]: 'an identifier starting with an uppercase letter',
[SyntaxKind.ConstantExpression]: 'a constant expression', [SyntaxKind.ConstantExpression]: 'a constant expression',
[SyntaxKind.ReferenceExpression]: 'a reference expression', [SyntaxKind.ReferenceExpression]: 'a reference expression',
@ -196,6 +200,7 @@ export function describeType(type: Type): string {
} }
return out; return out;
} }
case TypeKind.Variant:
case TypeKind.Record: case TypeKind.Record:
{ {
return type.decl.name.text; return type.decl.name.text;
@ -220,6 +225,17 @@ export function describeType(type: Type): string {
} }
} }
function describeKind(kind: Kind): string {
switch (kind.type) {
case KindType.Var:
return `a${kind.id}`;
case KindType.Arrow:
return describeKind(kind.left) + ' -> ' + describeKind(kind.right);
case KindType.Star:
return '*';
}
}
function getFirstNodeInTypeChain(type: Type): Syntax | null { function getFirstNodeInTypeChain(type: Type): Syntax | null {
let curr = type.next; let curr = type.next;
while (curr !== type && (curr.kind === TypeKind.Var || curr.node === null)) { while (curr !== type && (curr.kind === TypeKind.Var || curr.node === null)) {
@ -344,8 +360,8 @@ export class KindMismatchDiagnostic {
public readonly level = Level.Error; public readonly level = Level.Error;
public constructor( public constructor(
public leftSize: number, public left: Kind,
public rightSize: number, public right: Kind,
public node: Syntax | null, public node: Syntax | null,
) { ) {
@ -353,15 +369,7 @@ export class KindMismatchDiagnostic {
public format(out: IndentWriter): void { public format(out: IndentWriter): void {
out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET); out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET);
out.write(`kind `); out.write(`kind ${describeKind(this.left)} does not match with ${describeKind(this.right)}\n\n`);
for (let i = 0; i < this.leftSize-1; i++) {
out.write(`* -> `);
}
out.write(`* does not match with `);
for (let i = 0; i < this.rightSize-1; i++) {
out.write(`* -> `);
}
out.write(`*\n\n`);
if (this.node !== null) { if (this.node !== null) {
out.write(printNode(this.node) + '\n'); out.write(printNode(this.node) + '\n');
} }

View file

@ -54,6 +54,8 @@ import {
VarTypeExpression, VarTypeExpression,
TypeDeclaration, TypeDeclaration,
AppTypeExpression, AppTypeExpression,
NestedPattern,
NestedTypeExpression,
} from "./cst" } from "./cst"
import { Stream } from "./util"; import { Stream } from "./util";
@ -176,6 +178,13 @@ export class Parser {
this.getToken(); this.getToken();
return new VarTypeExpression(t0); return new VarTypeExpression(t0);
} }
case SyntaxKind.LParen:
{
this.getToken();
const typeExpr = this.parseTypeExpression();
const t2 = this.expectToken(SyntaxKind.RParen);
return new NestedTypeExpression(t0, typeExpr, t2);
}
case SyntaxKind.IdentifierAlt: case SyntaxKind.IdentifierAlt:
return this.parseReferenceTypeExpression(); return this.parseReferenceTypeExpression();
default: default:
@ -477,11 +486,15 @@ export class Parser {
this.raiseParseError(t0, [ SyntaxKind.EnumKeyword ]); this.raiseParseError(t0, [ SyntaxKind.EnumKeyword ]);
} }
const name = this.expectToken(SyntaxKind.IdentifierAlt); const name = this.expectToken(SyntaxKind.IdentifierAlt);
const t1 = this.peekToken(); let t1 = this.getToken();
const varExps = [];
while (t1.kind === SyntaxKind.Identifier) {
varExps.push(t1);
t1 = this.getToken();
}
let members = null; let members = null;
if (t1.kind === SyntaxKind.BlockStart) { if (t1.kind === SyntaxKind.BlockStart) {
members = []; members = [];
this.getToken();
for (;;) { for (;;) {
const t2 = this.peekToken(); const t2 = this.peekToken();
if (t2.kind === SyntaxKind.BlockEnd) { if (t2.kind === SyntaxKind.BlockEnd) {
@ -498,6 +511,7 @@ export class Parser {
const name = this.expectToken(SyntaxKind.Identifier); const name = this.expectToken(SyntaxKind.Identifier);
const colon = this.expectToken(SyntaxKind.Colon); const colon = this.expectToken(SyntaxKind.Colon);
const typeExpr = this.parseTypeExpression(); const typeExpr = this.parseTypeExpression();
this.expectToken(SyntaxKind.LineFoldEnd);
members.push(new StructDeclarationField(name, colon, typeExpr)); members.push(new StructDeclarationField(name, colon, typeExpr));
const t4 = this.peekToken(); const t4 = this.peekToken();
if (t4.kind === SyntaxKind.BlockEnd) { if (t4.kind === SyntaxKind.BlockEnd) {
@ -521,9 +535,12 @@ export class Parser {
members.push(member); members.push(member);
this.expectToken(SyntaxKind.LineFoldEnd); this.expectToken(SyntaxKind.LineFoldEnd);
} }
t1 = this.getToken();
} }
this.expectToken(SyntaxKind.LineFoldEnd); if (t1.kind !== SyntaxKind.LineFoldEnd) {
return new EnumDeclaration(pubKeyword, t0, name, members); this.raiseParseError(t1, [ SyntaxKind.Identifier, SyntaxKind.BlockStart, SyntaxKind.LineFoldEnd ]);
}
return new EnumDeclaration(pubKeyword, t0, name, varExps, members);
} }
public parseStructDeclaration(): StructDeclaration { public parseStructDeclaration(): StructDeclaration {
@ -651,19 +668,26 @@ export class Parser {
switch (t0.kind) { switch (t0.kind) {
case SyntaxKind.LParen: case SyntaxKind.LParen:
{ {
const t1 = this.peekToken(); const t1 = this.peekToken(2);
if (t1.kind === SyntaxKind.IdentifierAlt) { if (t1.kind === SyntaxKind.IdentifierAlt) {
this.getToken(); this.getToken();
return this.parsePatternStartingWithConstructor(); const pattern = this.parsePatternStartingWithConstructor();
const t3 = this.expectToken(SyntaxKind.RParen);
return new NestedPattern(t0, pattern, t3);
} else { } else {
return this.parseTuplePattern(); return this.parseTuplePattern();
} }
} }
case SyntaxKind.IdentifierAlt: case SyntaxKind.IdentifierAlt:
return this.parsePatternStartingWithConstructor(); {
this.getToken();
return new NamedTuplePattern(t0, []);
}
case SyntaxKind.Identifier: case SyntaxKind.Identifier:
{
this.getToken(); this.getToken();
return new BindPattern(t0); return new BindPattern(t0);
}
default: default:
this.raiseParseError(t0, [ SyntaxKind.Identifier ]); this.raiseParseError(t0, [ SyntaxKind.Identifier ]);
} }
@ -866,7 +890,7 @@ export class Parser {
case SyntaxKind.StructKeyword: case SyntaxKind.StructKeyword:
return this.parseStructDeclaration(); return this.parseStructDeclaration();
case SyntaxKind.EnumKeyword: case SyntaxKind.EnumKeyword:
return this.parseStructDeclaration(); return this.parseEnumDeclaration();
case SyntaxKind.TypeKeyword: case SyntaxKind.TypeKeyword:
return this.parseTypeDeclaration(); return this.parseTypeDeclaration();
case SyntaxKind.IfKeyword: case SyntaxKind.IfKeyword: