From 00bcaa93eed732b80978302dff672dfce1f2d32b Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Wed, 7 Sep 2022 12:45:38 +0200 Subject: [PATCH] Make record types partially work --- src/checker.ts | 398 +++++++++++++++++++++++++++++++++++++++++++-- src/cst.ts | 29 +++- src/diagnostics.ts | 15 ++ src/parser.ts | 22 ++- src/scanner.ts | 2 + 5 files changed, 442 insertions(+), 24 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index 8cf91d828..d16ea8ee1 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -2,7 +2,9 @@ import { Expression, LetDeclaration, Pattern, + Scope, SourceFile, + StructDeclaration, Syntax, SyntaxKind, TypeExpression @@ -25,6 +27,8 @@ export enum TypeKind { Con, Any, Tuple, + Labeled, + Record, } abstract class TypeBase { @@ -183,13 +187,82 @@ class TTuple extends TypeBase { } +class TLabeled extends TypeBase { + + public readonly kind = TypeKind.Labeled; + + public fields?: Map; + public parent: TLabeled | null = null; + + public constructor( + public name: string, + public type: Type, + ) { + super(); + } + + public find(): TLabeled { + let curr: TLabeled | null = this; + while (curr.parent !== null) { + curr = curr.parent; + } + this.parent = curr; + return curr; + } + + public getTypeVars(): Iterable { + return this.type.getTypeVars(); + } + + public substitute(sub: TVSub): Type { + const newType = this.type.substitute(sub); + return newType !== this.type ? new TLabeled(this.name, newType) : this; + } + +} + +class TRecord extends TypeBase { + + public readonly kind = TypeKind.Record; + + public nextRecord: TRecord | null = null; + + public constructor( + public decl: StructDeclaration, + public fields: Map, + ) { + super(); + } + + public *getTypeVars(): Iterable { + for (const type of this.fields.values()) { + yield* type.getTypeVars(); + } + } + + public substitute(sub: TVSub): Type { + let changed = false; + const newFields = new Map(); + for (const [key, type] of this.fields) { + const newType = type.substitute(sub); + if (newType !== type) { + changed = true; + } + newFields.set(key, newType); + } + return changed ? new TRecord(this.decl, newFields) : this; + } + +} + export type Type = TCon | TArrow | TVar | TAny | TTuple - + | TLabeled + | TRecord class TVSet { @@ -251,6 +324,7 @@ class TVSub { const enum ConstraintKind { Equal, Many, + Shaped, } abstract class ConstraintBase { @@ -276,6 +350,26 @@ abstract class ConstraintBase { } +class CShaped extends ConstraintBase { + + public readonly kind = ConstraintKind.Shaped; + + public constructor( + public recordType: TLabeled, + public type: Type, + ) { + super(); + } + + public substitute(sub: TVSub): Constraint { + return new CShaped( + this.recordType.substitute(sub) as TLabeled, + this.type.substitute(sub), + ); + } + +} + class CEqual extends ConstraintBase { public readonly kind = ConstraintKind.Equal; @@ -325,6 +419,7 @@ class CMany extends ConstraintBase { type Constraint = CEqual | CMany + | CShaped class ConstraintSet extends Array { } @@ -384,6 +479,7 @@ export class Checker { private nextTypeVarId = 0; private nextConTypeId = 0; + private nextRecordTypeId = 0; //private graph?: Graph; //private currentCycle?: Map; @@ -520,10 +616,49 @@ export class Checker { } case SyntaxKind.LetDeclaration: + { + if (node.pattern.kind === SyntaxKind.BindPattern) { + break; + } + const type = this.inferBindings(node.pattern, [], []); + if (node.typeAssert !== null) { + this.addConstraint( + new CEqual( + this.inferTypeExpression(node.typeAssert.typeExpression), + type, + node + ) + ); + } + if (node.body !== null) { + switch (node.body.kind) { + case SyntaxKind.ExprBody: + { + const type2 = this.inferExpression(node.body.expression); + this.addConstraint( + new CEqual( + type, + type2, + node + ) + ); + break; + } + case SyntaxKind.BlockBody: + { + // TODO + assert(false); + } + } + } + break; + } + + case SyntaxKind.StructDeclaration: break; default: - throw new Error(`Unexpected ${node}`); + throw new Error(`Unexpected ${node.constructor.name}`); } @@ -600,6 +735,64 @@ export class Checker { return new TCon(type.id, argTypes, type.displayName); } + case SyntaxKind.StructExpression: + { + const scheme = this.lookup(node.name.text); + if (scheme === null) { + this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); + return new TAny(); + } + const recordType = this.instantiate(scheme, node); + const type = this.createTypeVar(); + for (const member of node.members) { + switch (member.kind) { + case SyntaxKind.StructExpressionField: + { + this.addConstraint( + new CEqual( + new TLabeled( + member.name.text, + this.inferExpression(member.expression) + ), + type, + member, + ) + ); + break; + } + case SyntaxKind.PunnedStructExpressionField: + { + const scheme = this.lookup(member.name.text); + let fieldType; + if (scheme === null) { + this.diagnostics.add(new BindingNotFoudDiagnostic(member.name.text, member.name)); + fieldType = new TAny(); + } else { + fieldType = this.instantiate(scheme, member); + } + this.addConstraint( + new CEqual( + fieldType, + type, + member + ) + ); + break; + } + default: + throw new Error(`Unexpected ${member}`); + } + } + this.addConstraint( + new CEqual( + recordType, + type, + node, + ) + ); + return type; + } + case SyntaxKind.InfixExpression: { const scheme = this.lookup(node.operator.text); @@ -649,22 +842,87 @@ export class Checker { } - public inferBindings(pattern: Pattern, type: Type, tvs: TVar[], constraints: Constraint[]): void { + public inferBindings(pattern: Pattern, typeVars: TVar[], constraints: Constraint[]): Type { switch (pattern.kind) { case SyntaxKind.BindPattern: { - this.addBinding(pattern.name.text, new Forall(tvs, constraints, type)); - break; + const type = this.createTypeVar(); + this.addBinding(pattern.name.text, new Forall(typeVars, constraints, type)); + return type; } + case SyntaxKind.StructPattern: + { + const scheme = this.lookup(pattern.name.text); + let recordType; + if (scheme === null) { + this.diagnostics.add(new BindingNotFoudDiagnostic(pattern.name.text, pattern.name)); + recordType = new TAny(); + } else { + recordType = this.instantiate(scheme, pattern.name); + } + const type = this.createTypeVar(); + for (const member of pattern.members) { + switch (member.kind) { + case SyntaxKind.StructPatternField: + { + const fieldType = this.inferBindings(member.pattern, typeVars, constraints); + this.addConstraint( + new CEqual( + new TLabeled(member.name.text, fieldType), + type, + member + ) + ); + break; + } + case SyntaxKind.PunnedStructPatternField: + { + const fieldType = this.createTypeVar(); + this.addBinding(member.name.text, new Forall([], [], fieldType)); + this.addConstraint( + new CEqual( + new TLabeled(member.name.text, fieldType), + type, + member + ) + ); + break; + } + default: + throw new Error(`Unexpected ${member.constructor.name}`); + } + } + this.addConstraint( + new CEqual( + recordType, + type, + pattern + ) + ); + return type; + } + + default: + throw new Error(`Unexpected ${pattern.constructor.name}`); + } } private addReferencesToGraph(graph: ReferenceGraph, node: Syntax, source: LetDeclaration | SourceFile) { + const addReference = (scope: Scope, name: string) => { + const target = scope.lookup(name); + if (target === null || target.kind === SyntaxKind.Param) { + return; + } + assert(target.kind === SyntaxKind.LetDeclaration || target.kind === SyntaxKind.SourceFile); + graph.addEdge(source, target, true); + } + switch (node.kind) { case SyntaxKind.ConstantExpression: @@ -681,12 +939,7 @@ export class Checker { case SyntaxKind.ReferenceExpression: { assert(node.name.modulePath.length === 0); - const target = node.getScope().lookup(node.name.name.text); - if (target === null || target.kind === SyntaxKind.Param) { - break; - } - assert(target.kind === SyntaxKind.LetDeclaration || target.kind === SyntaxKind.SourceFile); - graph.addEdge(source, target, true); + addReference(node.getScope(), node.name.name.text); break; } @@ -698,6 +951,25 @@ export class Checker { break; } + case SyntaxKind.StructExpression: + { + for (const member of node.members) { + switch (member.kind) { + case SyntaxKind.PunnedStructExpressionField: + { + addReference(node.getScope(), node.name.text); + break; + } + case SyntaxKind.StructExpressionField: + { + this.addReferencesToGraph(graph, member.expression, source); + break; + }; + } + } + break; + } + case SyntaxKind.NestedExpression: { this.addReferencesToGraph(graph, node.expression, source); @@ -738,6 +1010,7 @@ export class Checker { this.addReferencesToGraph(graph, node.expression, source); break; } + case SyntaxKind.ReturnStatement: { if (node.expression !== null) { @@ -745,6 +1018,7 @@ export class Checker { } break; } + case SyntaxKind.LetDeclaration: { graph.addVertex(node); @@ -767,6 +1041,9 @@ export class Checker { break; } + case SyntaxKind.StructDeclaration: + break; + default: throw new Error(`Unexpected ${node.constructor.name}`); @@ -804,6 +1081,7 @@ export class Checker { case SyntaxKind.IfStatement: case SyntaxKind.ReturnStatement: case SyntaxKind.ExpressionStatement: + case SyntaxKind.StructDeclaration: break; default: @@ -813,14 +1091,15 @@ export class Checker { } - private initialize(node: Syntax, parentEnv: TypeEnv | null): void { + private initialize(node: Syntax, parentEnv: TypeEnv): void { switch (node.kind) { case SyntaxKind.SourceFile: { + const env = node.typeEnv = new TypeEnv(parentEnv); for (const element of node.elements) { - this.initialize(element, parentEnv); + this.initialize(element, env); } break; } @@ -839,9 +1118,21 @@ export class Checker { case SyntaxKind.IfStatement: case SyntaxKind.ExpressionStatement: case SyntaxKind.ReturnStatement: - case SyntaxKind.StructDeclaration: break; + case SyntaxKind.StructDeclaration: + { + const fields = new Map(); + if (node.members !== null) { + for (const member of node.members) { + fields.set(member.name.text, this.inferTypeExpression(member.typeExpr)); + } + } + const type = new TRecord(node, fields); + parentEnv.add(node.name.text, new Forall([], [], type)); + break; + } + default: throw new Error(`Unexpected ${node}`); @@ -877,6 +1168,13 @@ export class Checker { this.initialize(node, env); + this.pushContext({ + typeVars, + constraints, + env: node.typeEnv!, + returnType: null + }); + const sccs = [...strongconnect(graph)]; for (const nodes of sccs) { @@ -909,8 +1207,7 @@ export class Checker { const paramTypes = []; for (const param of node.params) { - const paramType = this.createTypeVar() - this.inferBindings(param.pattern, paramType, [], []); + const paramType = this.inferBindings(param.pattern, [], []); paramTypes.push(paramType); } @@ -928,7 +1225,18 @@ export class Checker { this.contexts.pop(); - this.inferBindings(node.pattern, type, typeVars, constraints); + // FIXME get rid of all this useless stack manipulation + const parentDecl = node.parent!.getScope().node; + const bindCtx = { + typeVars: context.typeVars, + constraints: context.constraints, + env: parentDecl.typeEnv!, + returnType: null, + }; + this.contexts.push(bindCtx) + const ty2 = this.inferBindings(node.pattern, typeVars, constraints); + this.addConstraint(new CEqual(ty2, type, node)); + this.contexts.pop(); } } @@ -1006,6 +1314,7 @@ export class Checker { } } + this.contexts.pop(); this.popContext(context); this.solve(new CMany(constraints), this.solution); @@ -1033,7 +1342,6 @@ export class Checker { case ConstraintKind.Equal: { - //constraint.dump(); if (!this.unify(constraint.left, constraint.right, solution, constraint)) { errorCount++; if (errorCount === MAX_TYPE_ERROR_COUNT) { @@ -1114,6 +1422,60 @@ export class Checker { } } + if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) { + //const remaining = new Set(right.fields.keys()); + let success = false; + + const root = left.find(); + right.parent = root; + if (root.fields === undefined) { + root.fields = new Map([ [ root.name, root.type ] ]); + } + if (right.fields === undefined) { + right.fields = new Map([ [ right.name, right.type ] ]); + } + for (const [fieldName, fieldType] of right.fields) { + if (root.fields.has(fieldName)) { + if (!this.unify(root.fields.get(fieldName)!, fieldType, solution, constraint)) { + success = false; + } + } else { + root.fields.set(fieldName, fieldType); + } + } + delete right.fields; + return success; + } + + + if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { + let success = true; + if (right.fields === undefined) { + right.fields = new Map([ [ right.name, right.type ] ]); + } + const remaining = new Set(right.fields.keys()); + for (const [fieldName, fieldType] of left.fields) { + if (right.fields.has(fieldName)) { + if (!this.unify(fieldType, right.fields.get(fieldName)!, solution, constraint)) { + success = false; + } + remaining.delete(fieldName); + } + } + for (const fieldName of remaining) { + if (left.fields.has(fieldName)) { + if (!this.unify(left.fields.get(fieldName)!, right.fields.get(fieldName)!, solution, constraint)) { + success = false; + } + } + } + return success; + } + + if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Record) { + return this.unify(right, left, solution, constraint); + } + this.diagnostics.add( new UnificationFailedDiagnostic( left.substitute(solution), diff --git a/src/cst.ts b/src/cst.ts index 1625bc236..a35f276e5 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -178,6 +178,8 @@ export type Syntax | Expression | TypeExpression | Pattern + | StructExpressionElement + | StructPatternElement function isIgnoredProperty(key: string): boolean { return key === 'kind' || key === 'parent'; @@ -226,6 +228,8 @@ export class Scope { case SyntaxKind.ReturnStatement: case SyntaxKind.IfStatement: break; + case SyntaxKind.StructDeclaration: + break; case SyntaxKind.LetDeclaration: { for (const param of node.params) { @@ -254,6 +258,24 @@ export class Scope { this.mapping.set(node.name.text, decl); break; } + case SyntaxKind.StructPattern: + { + for (const member of node.members) { + switch (member.kind) { + case SyntaxKind.StructPatternField: + { + this.scanPattern(member.pattern, decl); + break; + } + case SyntaxKind.PunnedStructPatternField: + { + this.mapping.set(node.name.text, decl); + break; + } + } + } + break; + } default: throw new Error(`Unexpected ${node}`); } @@ -1006,7 +1028,7 @@ export class StructPattern extends SyntaxBase { public constructor( public name: IdentifierAlt, public lbrace: LBrace, - public elements: StructPatternElement[], + public members: StructPatternElement[], public rbrace: RBrace, ) { super(); @@ -1216,7 +1238,7 @@ export class StructExpression extends SyntaxBase { public constructor( public name: IdentifierAlt, public lbrace: LBrace, - public elements: StructExpressionElement[], + public members: StructExpressionElement[], public rbrace: RBrace, ) { super(); @@ -1495,7 +1517,7 @@ export class StructDeclaration extends SyntaxBase { public constructor( public structKeyword: StructKeyword, - public name: Identifier, + public name: IdentifierAlt, public members: StructDeclarationField[] | null, ) { super(); @@ -1714,6 +1736,7 @@ export class SourceFile extends SyntaxBase { public readonly kind = SyntaxKind.SourceFile; public scope?: Scope; + public typeEnv?: TypeEnv; public constructor( private file: TextFile, diff --git a/src/diagnostics.ts b/src/diagnostics.ts index 76ad0173d..4bde0bdee 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -182,6 +182,21 @@ export function describeType(type: Type): string { } return out; } + case TypeKind.Record: + { + let out = type.decl.name.text + ' { '; + let first = true; + for (const [fieldName, fieldType] of type.fields) { + if (first) first = false; + else out += ', '; + out += fieldName + ': ' + describeType(fieldType); + } + return out + ' }'; + } + case TypeKind.Labeled: + { + return '{ ' + type.name + ': ' + describeType(type.type) + ' }'; + } } } diff --git a/src/parser.ts b/src/parser.ts index d68c70ec4..cacffabd2 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -229,6 +229,7 @@ export class Parser { for (;;) { const t2 = this.peekToken(); if (t2.kind === SyntaxKind.RBrace) { + this.getToken(); rbrace = t2; break; } @@ -253,6 +254,7 @@ export class Parser { this.getToken(); continue; } else if (t5.kind === SyntaxKind.RBrace) { + this.getToken(); rbrace = t5; break; } @@ -264,6 +266,8 @@ export class Parser { const t2 = this.peekToken(); if (t2.kind === SyntaxKind.LineFoldEnd || t2.kind === SyntaxKind.RParen + || t2.kind === SyntaxKind.RBrace + || t2.kind === SyntaxKind.RBracket || isBinaryOperatorLike(t2) || isPrefixOperatorLike(t2)) { break; @@ -292,8 +296,11 @@ export class Parser { for (;;) { const t1 = this.peekToken(); if (t1.kind === SyntaxKind.LineFoldEnd + || t1.kind === SyntaxKind.RBrace + || t1.kind === SyntaxKind.RBracket || t1.kind === SyntaxKind.RParen || t1.kind === SyntaxKind.BlockStart + || t1.kind === SyntaxKind.Comma || isBinaryOperatorLike(t1) || isPrefixOperatorLike(t1)) { break; @@ -364,22 +371,27 @@ export class Parser { public parseStructDeclaration(): StructDeclaration { const structKeyword = this.expectToken(SyntaxKind.StructKeyword); - const name = this.expectToken(SyntaxKind.Identifier); + const name = this.expectToken(SyntaxKind.IdentifierAlt); const t2 = this.peekToken() let members = null; if (t2.kind === SyntaxKind.BlockStart) { this.getToken(); members = []; for (;;) { + const t3 = this.peekToken(); + if (t3.kind === SyntaxKind.BlockEnd) { + this.getToken(); + break; + } const name = this.expectToken(SyntaxKind.Identifier); const colon = this.expectToken(SyntaxKind.Colon); const typeExpr = this.parseTypeExpression(); + this.expectToken(SyntaxKind.LineFoldEnd); const member = new StructDeclarationField(name, colon, typeExpr); members.push(member); } - } else { - this.assertToken(t2, SyntaxKind.LineFoldEnd); } + this.expectToken(SyntaxKind.LineFoldEnd); return new StructDeclaration(structKeyword, name, members); } @@ -393,6 +405,7 @@ export class Parser { for (;;) { const t3 = this.peekToken(); if (t3.kind === SyntaxKind.RBrace) { + this.getToken(); rbrace = t3; break; } else if (t3.kind === SyntaxKind.Identifier) { @@ -415,6 +428,7 @@ export class Parser { if (t5.kind === SyntaxKind.Comma) { this.getToken(); } else if (t5.kind === SyntaxKind.RBrace) { + this.getToken(); rbrace = t5; break; } else { @@ -473,6 +487,8 @@ export class Parser { return this.parseTuplePattern(); } } + case SyntaxKind.IdentifierAlt: + return this.parsePatternStartingWithConstructor(); case SyntaxKind.Identifier: this.getToken(); return new BindPattern(t0); diff --git a/src/scanner.ts b/src/scanner.ts index 4f89bd926..354a3d89a 100644 --- a/src/scanner.ts +++ b/src/scanner.ts @@ -34,6 +34,7 @@ import { ElifKeyword, ElseKeyword, IfKeyword, + StructKeyword, } from "./cst" import { Diagnostics, UnexpectedCharDiagnostic } from "./diagnostics" import { Stream, BufferedStream, assert } from "./util"; @@ -351,6 +352,7 @@ export class Scanner extends BufferedStream { case 'if': return new IfKeyword(startPos); case 'else': return new ElseKeyword(startPos); case 'elif': return new ElifKeyword(startPos); + case 'struct': return new StructKeyword(startPos); default: if (isUpper(text[0])) { return new IdentifierAlt(text, startPos);