From 5f373a9d130629e206153b5aa1c0e12bdbe44595 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sat, 10 Sep 2022 16:52:14 +0200 Subject: [PATCH] Make struct-declarations polymorphic --- src/checker.ts | 46 ++++++++++++++++++++++++++++++++++++++++++---- src/cst.ts | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/parser.ts | 22 ++++++++++++++++++---- 3 files changed, 110 insertions(+), 8 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index 67f41f996..01603f678 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -482,7 +482,7 @@ abstract class SchemeBase { class Forall extends SchemeBase { public constructor( - public tvs: TVar[], + public typeVars: TVar[], public constraints: Constraint[], public type: Type, ) { @@ -599,7 +599,7 @@ export class Checker { private instantiate(scheme: Scheme, node: Syntax | null): Type { const sub = new TVSub(); - for (const tv of scheme.tvs) { + for (const tv of scheme.typeVars) { sub.set(tv, this.createTypeVar()); } for (const constraint of scheme.constraints) { @@ -891,7 +891,7 @@ export class Checker { } - public inferTypeExpression(node: TypeExpression): Type { + public inferTypeExpression(node: TypeExpression, introduceTypeVars = false): Type { switch (node.kind) { @@ -907,6 +907,22 @@ export class Checker { return type; } + case SyntaxKind.VarTypeExpression: + { + const scheme = this.lookup(node.name.text); + if (scheme === null) { + if (!introduceTypeVars) { + this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); + } + const type = this.createTypeVar(); + this.addBinding(node.name.text, new Forall([], [], type)); + return type; + } + assert(scheme.typeVars.length === 0); + assert(scheme.constraints.length === 0); + return scheme.type; + } + case SyntaxKind.ArrowTypeExpression: { const paramTypes = []; @@ -1129,6 +1145,7 @@ export class Checker { break; } + case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: break; @@ -1169,6 +1186,7 @@ export class Checker { case SyntaxKind.IfStatement: case SyntaxKind.ReturnStatement: case SyntaxKind.ExpressionStatement: + case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: break; @@ -1208,16 +1226,36 @@ export class Checker { case SyntaxKind.ReturnStatement: break; + case SyntaxKind.EnumDeclaration: + { + // TODO complete this + break; + } + case SyntaxKind.StructDeclaration: { + const env = node.typeEnv = new TypeEnv(parentEnv); + const typeVars = new TVSet(); + const constraints = new ConstraintSet(); + const context: InferContext = { + constraints, + typeVars, + env, + returnType: null, + }; + this.pushContext(context); + for (const varExpr of node.typeVars) { + env.add(varExpr.text, new Forall([], [], this.createTypeVar())); + } const fields = new Map(); if (node.members !== null) { for (const member of node.members) { fields.set(member.name.text, this.inferTypeExpression(member.typeExpr)); } } + this.popContext(context); const type = new TRecord(node, fields); - parentEnv.add(node.name.text, new Forall([], [], type)); + parentEnv.add(node.name.text, new Forall(typeVars, constraints, type)); break; } diff --git a/src/cst.ts b/src/cst.ts index 7c61cacfc..8d9040c1c 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -107,6 +107,7 @@ export const enum SyntaxKind { // Type expressions ReferenceTypeExpression, ArrowTypeExpression, + VarTypeExpression, // Patterns BindPattern, @@ -240,6 +241,7 @@ export class Scope { case SyntaxKind.ReturnStatement: case SyntaxKind.IfStatement: break; + case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: break; case SyntaxKind.LetDeclaration: @@ -411,6 +413,32 @@ abstract class SyntaxBase { } +export function forEachChild(node: Syntax, callback: (node: Syntax) => void): void { + + for (const key of Object.getOwnPropertyNames(node)) { + if (isIgnoredProperty(key)) { + continue; + } + visitField((node as any)[key]); + } + + function visitField(field: any): void { + if (field === null) { + return; + } + if (Array.isArray(field)) { + for (const element of field) { + visitField(element); + } + return; + } + if (field instanceof SyntaxBase) { + callback(field as Syntax); + } + } + +} + abstract class TokenBase extends SyntaxBase { private endPos: TextPosition | null = null; @@ -939,9 +967,30 @@ export class ReferenceTypeExpression extends SyntaxBase { } +export class VarTypeExpression extends SyntaxBase { + + public readonly kind = SyntaxKind.VarTypeExpression; + + public constructor( + public name: Identifier + ) { + super(); + } + + public getFirstToken(): Token { + return this.name; + } + + public getLastToken(): Token { + return this.name; + } + +} + export type TypeExpression = ReferenceTypeExpression | ArrowTypeExpression + | VarTypeExpression export class BindPattern extends SyntaxBase { @@ -1686,6 +1735,7 @@ export class StructDeclaration extends SyntaxBase { public pubKeyword: PubKeyword | null, public structKeyword: StructKeyword, public name: IdentifierAlt, + public typeVars: Identifier[], public members: StructDeclarationField[] | null, ) { super(); diff --git a/src/parser.ts b/src/parser.ts index aa910df68..67703e6af 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -51,6 +51,7 @@ import { EnumDeclarationStructElement, EnumDeclaration, EnumDeclarationTupleElement, + VarTypeExpression, } from "./cst" import { Stream } from "./util"; @@ -168,6 +169,11 @@ export class Parser { public parsePrimitiveTypeExpression(): TypeExpression { const t0 = this.peekToken(); switch (t0.kind) { + case SyntaxKind.Identifier: + { + this.getToken(); + return new VarTypeExpression(t0); + } case SyntaxKind.IdentifierAlt: return this.parseReferenceTypeExpression(); default: @@ -288,6 +294,7 @@ export class Parser { for (;;) { const t2 = this.peekToken(); if (t2.kind === SyntaxKind.LineFoldEnd + || t2.kind === SyntaxKind.Comma || t2.kind === SyntaxKind.RParen || t2.kind === SyntaxKind.RBrace || t2.kind === SyntaxKind.RBracket @@ -481,10 +488,14 @@ export class Parser { this.raiseParseError(t0, [ SyntaxKind.StructKeyword ]); } const name = this.expectToken(SyntaxKind.IdentifierAlt); - const t2 = this.peekToken() + let t2 = this.getToken(); + const typeVars = []; + while (t2.kind === SyntaxKind.Identifier) { + typeVars.push(t2); + t2 = this.getToken(); + } let members = null; if (t2.kind === SyntaxKind.BlockStart) { - this.getToken(); members = []; for (;;) { const t3 = this.peekToken(); @@ -499,9 +510,12 @@ export class Parser { const member = new StructDeclarationField(name, colon, typeExpr); members.push(member); } + t2 = this.getToken(); } - this.expectToken(SyntaxKind.LineFoldEnd); - return new StructDeclaration(pubKeyword, t0, name, members); + if (t2.kind !== SyntaxKind.LineFoldEnd) { + this.raiseParseError(t2, [ SyntaxKind.LineFoldEnd, SyntaxKind.BlockStart, SyntaxKind.Identifier ]); + } + return new StructDeclaration(pubKeyword, t0, name, typeVars, members); } private parsePatternStartingWithConstructor() {