From 2d10ceedc90eb0d45ba9bd68d1638db0edf83d82 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Thu, 15 Sep 2022 20:33:34 +0200 Subject: [PATCH] Multiple enhancements - Make record expressions anonymous - Introduce `TNominal` - Add experimental support for type declarations (fixes #32) - Fix inference of StructDeclaration --- src/checker.ts | 175 +++++++++++++++++---------------------------- src/cst.ts | 25 ++++--- src/diagnostics.ts | 15 +++- src/parser.ts | 97 ++++++++++--------------- 4 files changed, 130 insertions(+), 182 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index d4d9e263d..18af9e9fd 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -1,4 +1,5 @@ import { + Declaration, EnumDeclaration, EnumDeclarationStructElement, Expression, @@ -25,6 +26,8 @@ import { import { assert, isEmpty, MultiMap } from "./util"; import { Analyser } from "./analysis"; +// TODO check that the order by which kindArgs are inserted is correct + const MAX_TYPE_ERROR_COUNT = 5; export enum TypeKind { @@ -36,7 +39,7 @@ export enum TypeKind { Labeled, Record, App, - Variant, + Nominal, } abstract class TypeBase { @@ -277,8 +280,6 @@ export class TRecord extends TypeBase { public readonly kind = TypeKind.Record; public constructor( - public decl: StructDeclaration | EnumDeclarationStructElement, - public kindArgs: TVar[], public fields: Map, public node: Syntax | null = null, ) { @@ -293,8 +294,6 @@ export class TRecord extends TypeBase { public shallowClone(): TRecord { return new TRecord( - this.decl, - this.kindArgs, this.fields, this.node ); @@ -302,15 +301,6 @@ export class TRecord extends TypeBase { public substitute(sub: TVSub): Type { let changed = false; - const newTypeVars = []; - for (const typeVar of this.kindArgs) { - const newTypeVar = typeVar.substitute(sub); - assert(newTypeVar.kind === TypeKind.Var); - if (newTypeVar !== typeVar) { - changed = true; - } - newTypeVars.push(newTypeVar); - } const newFields = new Map(); for (const [key, type] of this.fields) { const newType = type.substitute(sub); @@ -319,7 +309,7 @@ export class TRecord extends TypeBase { } newFields.set(key, newType); } - return changed ? new TRecord(this.decl, newTypeVars, newFields, this.node) : this; + return changed ? new TRecord(newFields, this.node) : this; } } @@ -336,12 +326,11 @@ export class TApp extends TypeBase { super(node); } - public static build(types: Type[], node: Syntax | null = null): Type { - let result = types[0]; - for (let i = 1; i < types.length; i++) { - result = new TApp(result, types[i], node); + public static build(resultType: Type, types: Type[], node: Syntax | null = null): Type { + for (let i = 0; i < types.length; i++) { + resultType = new TApp(types[i], resultType, node); } - return result; + return resultType; } public *getTypeVars(): Iterable { @@ -372,54 +361,30 @@ export class TApp extends TypeBase { } -export class TVariant extends TypeBase { +export class TNominal extends TypeBase { - public readonly kind = TypeKind.Variant; + public readonly kind = TypeKind.Nominal; public constructor( - public decl: EnumDeclaration, - public kindArgs: Type[], - public elementTypes: Type[], + public decl: Declaration, public node: Syntax | null = null, ) { super(node); } public *getTypeVars(): Iterable { - for (const elementType of this.elementTypes) { - yield* elementType.getTypeVars(); - } + } public shallowClone(): Type { - return new TVariant( + return new TNominal( this.decl, - this.kindArgs, - this.elementTypes, this.node, ); } public substitute(sub: TVSub): Type { - let changed = false; - const newTypeVars = []; - for (const kindArg of this.kindArgs) { - const newTypeVar = kindArg.substitute(sub); - assert(newTypeVar.kind === TypeKind.Var); - if (newTypeVar !== kindArg) { - changed = true; - } - newTypeVars.push(newTypeVar); - } - const newElementTypes = []; - for (const elementType of this.elementTypes) { - const newElementType = elementType.substitute(sub); - if (newElementType !== elementType) { - changed = true; - } - newElementTypes.push(newElementType); - } - return changed ? new TVariant(this.decl, newTypeVars, newElementTypes, this.node) : this; + return this; } } @@ -432,16 +397,7 @@ export type Type | TLabeled | TRecord | TApp - | TVariant - -type KindedType - = TRecord - | TVariant - -function isKindedType(type: Type): type is KindedType { - return type.kind === TypeKind.Variant - || type.kind === TypeKind.Record; -} + | TNominal export const enum KindType { Star, @@ -973,7 +929,7 @@ export class Checker { } } - + private inferKind(node: Syntax, env: KindEnv): void { switch (node.kind) { case SyntaxKind.SourceFile: @@ -985,7 +941,21 @@ export class Checker { } case SyntaxKind.StructDeclaration: { - // TODO + const declKind = env.lookup(node.name.text)!; + const innerEnv = new KindEnv(env); + let kind: Kind = new KStar(); + for (let i = node.varExps.length-1; i >= 0; i--) { + const varExpr = node.varExps[i]; + const paramKind = this.createKindVar(); + innerEnv.setNamed(varExpr.text, paramKind); + kind = new KArrow(paramKind, kind); + } + this.unifyKind(declKind, kind, node); + if (node.fields !== null) { + for (const field of node.fields) { + this.unifyKind(this.inferKindFromTypeExpression(field.typeExpr, innerEnv), new KStar(), field.typeExpr); + } + } break; } case SyntaxKind.EnumDeclaration: @@ -1194,15 +1164,15 @@ export class Checker { case SyntaxKind.ReferenceExpression: { - assert(node.name.modulePath.length === 0); + assert(node.modulePath.length === 0); const scope = node.getScope(); - const target = scope.lookup(node.name.name.text); + const target = scope.lookup(node.name.text); if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) { return target.type!; } - const scheme = this.lookup(node.name.name.text, Symkind.Var); + const scheme = this.lookup(node.name.text, Symkind.Var); if (scheme === null) { - this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name)); + this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); return this.createTypeVar(); } const type = this.instantiate(scheme, node); @@ -1288,21 +1258,6 @@ export class Checker { case SyntaxKind.StructExpression: { - const scope = node.getScope(); - const decl = scope.lookup(node.name.text, Symkind.Type); - if (decl === null) { - this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); - return this.createTypeVar(); - } - assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement); - const scheme = decl.scheme!; - const declType = this.instantiate(scheme, node); - const kindArgs = []; - const varExps = decl.kind === SyntaxKind.StructDeclaration - ? decl.varExps : (decl.parent! as EnumDeclaration).varExps; - for (const _ of varExps) { - kindArgs.push(this.createTypeVar()); - } const fields = new Map(); for (const member of node.members) { switch (member.kind) { @@ -1328,19 +1283,7 @@ export class Checker { throw new Error(`Unexpected ${member}`); } } - let type: Type = TApp.build([ ...kindArgs, new TRecord(decl, [], fields, node) ]); - if (decl.kind === SyntaxKind.EnumDeclarationStructElement) { - // TODO - // type = this.buildVariantType(decl, type); - } - this.addConstraint( - new CEqual( - declType, - type, - node, - ) - ); - return type; + return new TRecord(fields, node); } case SyntaxKind.InfixExpression: @@ -1409,10 +1352,10 @@ export class Checker { case SyntaxKind.AppTypeExpression: { - return TApp.build([ - ...node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars)), + return TApp.build( this.inferTypeExpression(node.operator, introduceTypeVars), - ]); + node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars)), + ); } case SyntaxKind.ArrowTypeExpression: @@ -1550,7 +1493,7 @@ export class Checker { kindArgs.push(kindArg); } let elementTypes: Type[] = []; - const type = new TVariant(node, [], [], node); + const type = new TNominal(node, node); if (node.members !== null) { for (const member of node.members) { let elementType; @@ -1558,14 +1501,22 @@ export class Checker { case SyntaxKind.EnumDeclarationTupleElement: { const argTypes = member.elements.map(el => this.inferTypeExpression(el)); - elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ])); - parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType), Symkind.Var); + elementType = new TArrow(argTypes, TApp.build(type, kindArgs)); + break; + } + case SyntaxKind.EnumDeclarationStructElement: + { + const fields = new Map(); + for (const field of member.fields) { + fields.set(field.name.text, this.inferTypeExpression(field.typeExpr)); + } + elementType = new TArrow([ new TRecord(fields, member) ], TApp.build(type, kindArgs)); break; } - // TODO default: throw new Error(`Unexpected ${member}`); } + parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType), Symkind.Var); elementTypes.push(elementType); } } @@ -1586,14 +1537,17 @@ export class Checker { returnType: null, }; this.pushContext(context); + const kindArgs = []; for (const varExpr of node.varExps) { - env.add(varExpr.text, new Forall([], [], this.createTypeVar()), Symkind.Type); + const typeVar = this.createTypeVar(); + kindArgs.push(typeVar); + env.add(varExpr.text, new Forall([], [], typeVar), Symkind.Type); } const type = this.inferTypeExpression(node.typeExpression); + console.log(describeType(type)); this.popContext(context); - const scheme = new Forall(typeVars, constraints, type); + const scheme = new Forall(typeVars, constraints, TApp.build(type, kindArgs)); parentEnv.add(node.name.text, scheme, Symkind.Type); - node.scheme = scheme; break; } @@ -1616,20 +1570,21 @@ export class Checker { kindArgs.push(kindArg); } const fields = new Map(); - if (node.members !== null) { - for (const member of node.members) { + if (node.fields !== null) { + for (const member of node.fields) { fields.set(member.name.text, this.inferTypeExpression(member.typeExpr)); } } this.popContext(context); - const type = new TRecord(node, [], fields, node); + const type = new TNominal(node); parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type); - node.scheme = new Forall(typeVars, constraints, TApp.build([ ...kindArgs, type ])); + parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow([ new TRecord(fields, node) ], TApp.build(type, kindArgs))), Symkind.Var); + //node.scheme = new Forall(typeVars, constraints, ); break; } default: - throw new Error(`Unexpected ${node}`); + throw new Error(`Unexpected ${node.constructor.name}`); } @@ -1956,7 +1911,7 @@ export class Checker { return success; } - if (left.kind === TypeKind.Variant && right.kind === TypeKind.Variant) { + if (left.kind === TypeKind.Nominal && right.kind === TypeKind.Nominal) { if (left.decl !== right.decl) { this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()])); return false; diff --git a/src/cst.ts b/src/cst.ts index 127624d06..a7564c26b 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -1434,7 +1434,6 @@ export class StructExpression extends SyntaxBase { public readonly kind = SyntaxKind.StructExpression; public constructor( - public name: IdentifierAlt, public lbrace: LBrace, public members: StructExpressionElement[], public rbrace: RBrace, @@ -1443,7 +1442,7 @@ export class StructExpression extends SyntaxBase { } public getFirstToken(): Token { - return this.name; + return this.lbrace; } public getLastToken(): Token { @@ -1481,17 +1480,21 @@ export class ReferenceExpression extends SyntaxBase { public readonly kind = SyntaxKind.ReferenceExpression; public constructor( - public name: QualifiedName, + public modulePath: Array<[IdentifierAlt, Dot]>, + public name: Identifier | IdentifierAlt, ) { super(); } public getFirstToken(): Token { - return this.name.getFirstToken(); + if (this.modulePath.length > 0) { + return this.modulePath[0][0]; + } + return this.name; } public getLastToken(): Token { - return this.name.getLastToken(); + return this.name; } } @@ -1718,7 +1721,7 @@ export class EnumDeclarationStructElement extends SyntaxBase { public constructor( public name: IdentifierAlt, public blockStart: BlockStart, - public members: StructDeclarationField[], + public fields: StructDeclarationField[], ) { super(); } @@ -1728,8 +1731,8 @@ export class EnumDeclarationStructElement extends SyntaxBase { } public getLastToken(): Token { - if (this.members.length > 0) { - return this.members[this.members.length-1].getLastToken(); + if (this.fields.length > 0) { + return this.fields[this.fields.length-1].getLastToken(); } return this.blockStart; } @@ -1830,7 +1833,7 @@ export class StructDeclaration extends SyntaxBase { public structKeyword: StructKeyword, public name: IdentifierAlt, public varExps: Identifier[], - public members: StructDeclarationField[] | null, + public fields: StructDeclarationField[] | null, ) { super(); } @@ -1843,8 +1846,8 @@ export class StructDeclaration extends SyntaxBase { } public getLastToken(): Token { - if (this.members && this.members.length > 0) { - return this.members[this.members.length-1].getLastToken(); + if (this.fields && this.fields.length > 0) { + return this.fields[this.fields.length-1].getLastToken(); } return this.name; } diff --git a/src/diagnostics.ts b/src/diagnostics.ts index f4db3ac2f..232b11ec4 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -1,5 +1,4 @@ -import { describe } from "yargs"; import { TypeKind, type Type, type TArrow, TRecord, Kind, KindType } from "./checker"; import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst"; import { countDigits, IndentWriter } from "./util"; @@ -200,11 +199,21 @@ export function describeType(type: Type): string { } return out; } - case TypeKind.Variant: - case TypeKind.Record: + case TypeKind.Nominal: { return type.decl.name.text; } + case TypeKind.Record: + { + let out = '{ '; + let first = true; + for (const [fieldName, fieldType] of type.fields) { + if (first) first = false; + else out += ', '; + out += fieldName + ': ' + describeType(fieldType); + } + return out + ' }'; + } case TypeKind.Labeled: { // FIXME may need to include fields that were added during unification diff --git a/src/parser.ts b/src/parser.ts index 8c725b8f4..fac3055c3 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -241,7 +241,7 @@ export class Parser { return new ConstantExpression(token); } - public parseQualifiedName(): QualifiedName { + public parseReferenceExpression(): ReferenceExpression { const modulePath: Array<[IdentifierAlt, Dot]> = []; for (;;) { const t0 = this.peekToken(1); @@ -251,12 +251,11 @@ export class Parser { } modulePath.push([t0, t1]); } - const name = this.expectToken(SyntaxKind.Identifier); - return new QualifiedName(modulePath, name); - } - - public parseReferenceExpression(): ReferenceExpression { - return new ReferenceExpression(this.parseQualifiedName()); + const name = this.getToken(); + if (name.kind !== SyntaxKind.Identifier && name.kind !== SyntaxKind.IdentifierAlt) { + this.raiseParseError(name, [ SyntaxKind.Identifier, SyntaxKind.IdentifierAlt ]); + } + return new ReferenceExpression(modulePath, name); } private parseExpressionWithParens(): Expression { @@ -279,65 +278,47 @@ export class Parser { case SyntaxKind.LParen: return this.parseExpressionWithParens(); case SyntaxKind.Identifier: - return this.parseReferenceExpression(); case SyntaxKind.IdentifierAlt: + return this.parseReferenceExpression(); + case SyntaxKind.LBrace: { this.getToken(); - const t1 = this.peekToken(); - if (t1.kind === SyntaxKind.LBrace) { - this.getToken(); - const fields = []; - let rbrace; - for (;;) { - const t2 = this.peekToken(); - if (t2.kind === SyntaxKind.RBrace) { - this.getToken(); - rbrace = t2; - break; - } - let field; - const t3 = this.getToken(); - if (t3.kind === SyntaxKind.Identifier) { - const t4 = this.peekToken(); - if (t4.kind === SyntaxKind.Equals) { - this.getToken(); - const expression = this.parseExpression(); - field = new StructExpressionField(t3, t4, expression); - } else { - field = new PunnedStructExpressionField(t3); - } - } else { - // TODO add spread fields - this.raiseParseError(t3, [ SyntaxKind.Identifier ]); - } - fields.push(field); - const t5 = this.peekToken(); - if (t5.kind === SyntaxKind.Comma) { - this.getToken(); - continue; - } else if (t5.kind === SyntaxKind.RBrace) { - this.getToken(); - rbrace = t5; - break; - } - } - return new StructExpression(t0, t1, fields, rbrace); - } - const elements = []; + const fields = []; + let rbrace; for (;;) { const t2 = this.peekToken(); - if (t2.kind === SyntaxKind.LineFoldEnd - || t2.kind === SyntaxKind.Comma - || t2.kind === SyntaxKind.RParen - || t2.kind === SyntaxKind.RBrace - || t2.kind === SyntaxKind.RBracket - || isBinaryOperatorLike(t2) - || isPrefixOperatorLike(t2)) { + if (t2.kind === SyntaxKind.RBrace) { + this.getToken(); + rbrace = t2; + break; + } + let field; + const t3 = this.getToken(); + if (t3.kind === SyntaxKind.Identifier) { + const t4 = this.peekToken(); + if (t4.kind === SyntaxKind.Equals) { + this.getToken(); + const expression = this.parseExpression(); + field = new StructExpressionField(t3, t4, expression); + } else { + field = new PunnedStructExpressionField(t3); + } + } else { + // TODO add spread fields + this.raiseParseError(t3, [ SyntaxKind.Identifier ]); + } + fields.push(field); + const t5 = this.peekToken(); + if (t5.kind === SyntaxKind.Comma) { + this.getToken(); + continue; + } else if (t5.kind === SyntaxKind.RBrace) { + this.getToken(); + rbrace = t5; break; } - elements.push(this.parseExpression()); } - return new NamedTupleExpression(t0, elements); + return new StructExpression(t0, fields, rbrace); } case SyntaxKind.Integer: case SyntaxKind.StringLiteral: