From 7b3f1948bb87c7aca5196d23117cc2b5b5ebab14 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Fri, 9 Sep 2022 22:37:14 +0200 Subject: [PATCH] Remove TAny; support operator declarations and arrow type expressions --- src/bin/bolt-selftest.ts | 2 + src/checker.ts | 101 ++++++++++++++++++++++++--------------- src/cst.ts | 76 +++++++++++++++++++++++++++-- src/parser.ts | 59 +++++++++++++++++------ src/scanner.ts | 5 +- 5 files changed, 186 insertions(+), 57 deletions(-) diff --git a/src/bin/bolt-selftest.ts b/src/bin/bolt-selftest.ts index b93e94720..9ad56981f 100644 --- a/src/bin/bolt-selftest.ts +++ b/src/bin/bolt-selftest.ts @@ -1,5 +1,7 @@ #!/usr/bin/env node +import "source-map-support/register" + import * as commonmark from "commonmark" import { sync as globSync } from "glob" import fs from "fs"; diff --git a/src/checker.ts b/src/checker.ts index a94ad16b0..67f41f996 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -4,6 +4,7 @@ import { Pattern, Scope, SourceFile, + SourceFileElement, StructDeclaration, Syntax, SyntaxKind, @@ -526,11 +527,15 @@ export interface InferContext { returnType: Type | null; } +function isFunctionDeclarationLike(node: LetDeclaration): boolean { + return node.pattern.kind === SyntaxKind.BindPattern + && (node.params.length > 0 || (node.body !== null && node.body.kind === SyntaxKind.BlockBody)); +} + export class Checker { private nextTypeVarId = 0; private nextConTypeId = 0; - private nextRecordTypeId = 0; //private graph?: Graph; //private currentCycle?: Map; @@ -668,10 +673,16 @@ export class Checker { case SyntaxKind.LetDeclaration: { - if (node.pattern.kind === SyntaxKind.BindPattern) { + if (isFunctionDeclarationLike(node)) { break; } - const type = this.inferBindings(node.pattern, [], []); + let type; + if (node.pattern.kind === SyntaxKind.WrappedOperator) { + type = this.createTypeVar(); + this.addBinding(node.pattern.operator.text, new Forall([], [], type)); + } else { + type = this.inferBindings(node.pattern, [], []); + } if (node.typeAssert !== null) { this.addConstraint( new CEqual( @@ -733,7 +744,7 @@ export class Checker { const scheme = this.lookup(node.name.name.text); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name)); - return new TAny(); + return this.createTypeVar(); } const type = this.instantiate(scheme, node); type.node = node; @@ -812,7 +823,7 @@ export class Checker { const scheme = this.lookup(node.name.text); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); - return new TAny(); + return this.createTypeVar(); } const recordType = this.instantiate(scheme, node); assert(recordType.kind === TypeKind.Record); @@ -830,7 +841,7 @@ export class Checker { let fieldType; if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(member.name.text, member.name)); - fieldType = new TAny(); + fieldType = this.createTypeVar(); } else { fieldType = this.instantiate(scheme, member); } @@ -857,7 +868,7 @@ export class Checker { const scheme = this.lookup(node.operator.text); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.operator.text, node.operator)); - return new TAny(); + return this.createTypeVar(); } const opType = this.instantiate(scheme, node.operator); const retType = this.createTypeVar(); @@ -889,13 +900,23 @@ export class Checker { const scheme = this.lookup(node.name.text); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); - return new TAny(); + return this.createTypeVar(); } const type = this.instantiate(scheme, node.name); type.node = node; return type; } + case SyntaxKind.ArrowTypeExpression: + { + const paramTypes = []; + for (const paramTypeExpr of node.paramTypeExprs) { + paramTypes.push(this.inferTypeExpression(paramTypeExpr)); + } + const returnType = this.inferTypeExpression(node.returnTypeExpr); + return new TArrow(paramTypes, returnType, node); + } + default: throw new Error(`Unrecognised ${node}`); @@ -920,7 +941,7 @@ export class Checker { let recordType; if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(pattern.name.text, pattern.name)); - recordType = new TAny(); + recordType = this.createTypeVar(); } else { recordType = this.instantiate(scheme, pattern.name); } @@ -1258,6 +1279,10 @@ export class Checker { assert(node.kind === SyntaxKind.LetDeclaration); + if (!isFunctionDeclarationLike(node)) { + continue; + } + const env = node.typeEnv!; const context: InferContext = { typeVars, @@ -1301,13 +1326,34 @@ export class Checker { returnType: null, }; this.contexts.push(bindCtx) - const ty2 = this.inferBindings(node.pattern, typeVars, constraints); + let ty2; + if (node.pattern.kind === SyntaxKind.WrappedOperator) { + ty2 = this.createTypeVar(); + this.addBinding(node.pattern.operator.text, new Forall([], [], ty2)); + } else { + ty2 = this.inferBindings(node.pattern, typeVars, constraints); + } this.addConstraint(new CEqual(ty2, type, node)); this.contexts.pop(); } } + const visitElements = (elements: Syntax[]) => { + for (const element of elements) { + if (element.kind === SyntaxKind.LetDeclaration + && isFunctionDeclarationLike(element) + && graph.hasEdge(node, element, false)) { + assert(element.pattern.kind === SyntaxKind.BindPattern); + const scheme = this.lookup(element.pattern.name.text); + assert(scheme !== null); + this.instantiate(scheme, null); + } else { + this.infer(element); + } + } + } + for (const nodes of sccs) { if (nodes.some(n => n.kind === SyntaxKind.SourceFile)) { @@ -1324,6 +1370,10 @@ export class Checker { assert(node.kind === SyntaxKind.LetDeclaration); + if (!isFunctionDeclarationLike(node)) { + continue; + } + const context = node.context!; const returnType = context.returnType!; this.contexts.push(context); @@ -1343,17 +1393,7 @@ export class Checker { } case SyntaxKind.BlockBody: { - for (const element of node.body.elements) { - if (element.kind === SyntaxKind.LetDeclaration - && element.pattern.kind === SyntaxKind.BindPattern - && graph.hasEdge(node, element, false)) { - const scheme = this.lookup(element.pattern.name.text); - assert(scheme !== null); - this.instantiate(scheme, null); - } else { - this.infer(element); - } - } + visitElements(node.body.elements); break; } } @@ -1368,18 +1408,8 @@ export class Checker { } } - - for (const element of node.elements) { - if (element.kind === SyntaxKind.LetDeclaration - && element.pattern.kind === SyntaxKind.BindPattern - && graph.hasEdge(node, element, false)) { - const scheme = this.lookup(element.pattern.name.text); - assert(scheme !== null); - this.instantiate(scheme, null); - } else { - this.infer(element); - } - } + + visitElements(node.elements); this.contexts.pop(); this.popContext(context); @@ -1447,11 +1477,6 @@ export class Checker { return this.unify(right, left, solution, constraint); } - if (left.kind === TypeKind.Any || right.kind === TypeKind.Any) { - TypeBase.join(left, right); - return true; - } - if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { if (left.paramTypes.length !== right.paramTypes.length) { this.diagnostics.add(new ArityMismatchDiagnostic(left, right)); diff --git a/src/cst.ts b/src/cst.ts index e80596b53..2224f7c2f 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -78,6 +78,7 @@ export const enum SyntaxKind { RBrace, LBracket, RBracket, + RArrow, Dot, DotDot, Comma, @@ -104,6 +105,7 @@ export const enum SyntaxKind { // Type expressions ReferenceTypeExpression, + ArrowTypeExpression, // Patterns BindPattern, @@ -139,6 +141,9 @@ export const enum SyntaxKind { ExpressionStatement, IfStatement, + // If statement elements + IfStatementCase, + // Declarations VariableDeclaration, PrefixFuncDecl, @@ -156,7 +161,7 @@ export const enum SyntaxKind { StructDeclarationField, // Other nodes - IfStatementCase, + WrappedOperator, Initializer, QualifiedName, TypeAssert, @@ -243,7 +248,11 @@ export class Scope { } } } else { - this.scanPattern(node.pattern, node); + if (node.pattern.kind === SyntaxKind.WrappedOperator) { + this.mapping.set(node.pattern.operator.text, node); + } else { + this.scanPattern(node.pattern, node); + } } break; } @@ -816,8 +825,19 @@ export class LetKeyword extends TokenBase { } +export class RArrow extends TokenBase { + + public readonly kind = SyntaxKind.RArrow; + + public get text(): string { + return '->'; + } + +} + export type Token - = LParen + = RArrow + | LParen | RParen | LBrace | RBrace @@ -854,6 +874,30 @@ export type Token export type TokenKind = Token['kind'] +export class ArrowTypeExpression extends SyntaxBase { + + public readonly kind = SyntaxKind.ArrowTypeExpression; + + public constructor( + public paramTypeExprs: TypeExpression[], + public returnTypeExpr: TypeExpression + ) { + super(); + } + + public getFirstToken(): Token { + if (this.paramTypeExprs.length > 0) { + return this.paramTypeExprs[0].getFirstToken(); + } + return this.returnTypeExpr.getFirstToken(); + } + + public getLastToken(): Token { + return this.returnTypeExpr.getLastToken(); + } + +} + export class ReferenceTypeExpression extends SyntaxBase { public readonly kind = SyntaxKind.ReferenceTypeExpression; @@ -880,6 +924,7 @@ export class ReferenceTypeExpression extends SyntaxBase { export type TypeExpression = ReferenceTypeExpression + | ArrowTypeExpression export class BindPattern extends SyntaxBase { @@ -1632,6 +1677,29 @@ export class BlockBody extends SyntaxBase { } } + +export class WrappedOperator extends SyntaxBase { + + public readonly kind = SyntaxKind.WrappedOperator; + + public constructor( + public lparen: LParen, + public operator: CustomOperator, + public rparen: RParen, + ) { + super(); + } + + public getFirstToken(): Token { + return this.lparen; + } + + public getLastToken(): Token { + return this.rparen; + } + +} + export class LetDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.LetDeclaration; @@ -1646,7 +1714,7 @@ export class LetDeclaration extends SyntaxBase { public pubKeyword: PubKeyword | null, public letKeyword: LetKeyword, public mutKeyword: MutKeyword | null, - public pattern: Pattern, + public pattern: Pattern | WrappedOperator, public params: Param[], public typeAssert: TypeAssert | null, public body: Body | null, diff --git a/src/parser.ts b/src/parser.ts index a45764d8c..0c4cf0785 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -46,6 +46,8 @@ import { IfStatement, MemberExpression, IdentifierAlt, + WrappedOperator, + ArrowTypeExpression, } from "./cst" import { Stream } from "./util"; @@ -160,7 +162,7 @@ export class Parser { return new ReferenceTypeExpression([], name); } - public parseTypeExpression(): TypeExpression { + public parsePrimitiveTypeExpression(): TypeExpression { const t0 = this.peekToken(); switch (t0.kind) { case SyntaxKind.IdentifierAlt: @@ -170,6 +172,24 @@ export class Parser { } } + public parseTypeExpression(): TypeExpression { + let returnType = this.parsePrimitiveTypeExpression(); + const paramTypes = []; + for (;;) { + const t1 = this.peekToken(); + if (t1.kind !== SyntaxKind.RArrow) { + break; + } + this.getToken(); + paramTypes.push(returnType); + returnType = this.parsePrimitiveTypeExpression(); + } + if (paramTypes.length === 0) { + return returnType; + } + return new ArrowTypeExpression(paramTypes, returnType); + } + public parseConstantExpression(): ConstantExpression { const token = this.getToken() if (token.kind !== SyntaxKind.StringLiteral @@ -497,9 +517,9 @@ export class Parser { switch (t0.kind) { case SyntaxKind.LParen: { - this.getToken(); const t1 = this.peekToken(); if (t1.kind === SyntaxKind.IdentifierAlt) { + this.getToken(); return this.parsePatternStartingWithConstructor(); } else { return this.parseTuplePattern(); @@ -551,7 +571,18 @@ export class Parser { this.getToken(); mutKeyword = t1; } - const pattern = this.parsePattern(); + const t2 = this.peekToken(); + const t3 = this.peekToken(2); + const t4 = this.peekToken(3); + let pattern; + if (t2.kind === SyntaxKind.LParen && t3.kind === SyntaxKind.CustomOperator && t4.kind === SyntaxKind.RParen) { + this.getToken() + this.getToken(); + this.getToken(); + pattern = new WrappedOperator(t2, t3, t4); + } else { + pattern = this.parsePattern(); + } const params = []; for (;;) { const t2 = this.peekToken(); @@ -564,14 +595,14 @@ export class Parser { params.push(this.parseParam()); } let typeAssert = null; - let t3 = this.getToken(); - if (t3.kind === SyntaxKind.Colon) { + let t5 = this.getToken(); + if (t5.kind === SyntaxKind.Colon) { const typeExpression = this.parseTypeExpression(); - typeAssert = new TypeAssert(t3, typeExpression); - t3 = this.getToken(); + typeAssert = new TypeAssert(t5, typeExpression); + t5 = this.getToken(); } let body = null; - switch (t3.kind) { + switch (t5.kind) { case SyntaxKind.BlockStart: { const elements = []; @@ -583,22 +614,22 @@ export class Parser { } elements.push(this.parseLetBodyElement()); } - body = new BlockBody(t3, elements); - t3 = this.getToken(); + body = new BlockBody(t5, elements); + t5 = this.getToken(); break; } case SyntaxKind.Equals: { const expression = this.parseExpression(); - body = new ExprBody(t3, expression); - t3 = this.getToken(); + body = new ExprBody(t5, expression); + t5 = this.getToken(); break; } case SyntaxKind.LineFoldEnd: break; } - if (t3.kind !== SyntaxKind.LineFoldEnd) { - this.raiseParseError(t3, [ SyntaxKind.LineFoldEnd ]); + if (t5.kind !== SyntaxKind.LineFoldEnd) { + this.raiseParseError(t5, [ SyntaxKind.LineFoldEnd ]); } return new LetDeclaration( pubKeyword, diff --git a/src/scanner.ts b/src/scanner.ts index 354a3d89a..8b2459eb5 100644 --- a/src/scanner.ts +++ b/src/scanner.ts @@ -35,6 +35,7 @@ import { ElseKeyword, IfKeyword, StructKeyword, + RArrow, } from "./cst" import { Diagnostics, UnexpectedCharDiagnostic } from "./diagnostics" import { Stream, BufferedStream, assert } from "./util"; @@ -246,7 +247,9 @@ export class Scanner extends BufferedStream { case '?': { const text = c0 + this.takeWhile(isOperatorPart); - if (text === '=') { + if (text === '->') { + return new RArrow(startPos); + } else if (text === '=') { return new Equals(startPos); } else if (text.endsWith('=') && text[text.length-2] !== '=') { return new Assignment(text, startPos);