diff --git a/src/analysis.ts b/src/analysis.ts index df247d0d8..47ecac6ff 100644 --- a/src/analysis.ts +++ b/src/analysis.ts @@ -28,6 +28,14 @@ export class Analyser { case SyntaxKind.ConstantExpression: break; + case SyntaxKind.MatchExpression: + { + for (const arm of node.arms) { + visit(arm.expression, source); + } + break; + } + case SyntaxKind.SourceFile: { for (const element of node.elements) { diff --git a/src/checker.ts b/src/checker.ts index ea72a2fcf..786d8713f 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -786,13 +786,16 @@ export class Checker { private createTypeVar(): TVar { const typeVar = new TVar(this.nextTypeVarId++); - const context = this.contexts[this.contexts.length-1]; - context.typeVars.add(typeVar); + this.getContext().typeVars.add(typeVar); return typeVar; } + public getContext(): InferContext { + return this.contexts[this.contexts.length-1]; + } + private addConstraint(constraint: Constraint): void { - this.contexts[this.contexts.length-1].constraints.push(constraint); + this.getContext().constraints.push(constraint); } private pushContext(context: InferContext) { @@ -805,13 +808,12 @@ export class Checker { } private lookup(name: string, kind: Symkind): Scheme | null { - const context = this.contexts[this.contexts.length-1]; - return context.env.lookup(name, kind); + return this.getContext().env.lookup(name, kind); } private getReturnType(): Type { - const context = this.contexts[this.contexts.length-1]; - assert(context && context.returnType !== null); + const context = this.getContext(); + assert(context.returnType !== null); return context.returnType; } @@ -1209,6 +1211,47 @@ export class Checker { case SyntaxKind.NestedExpression: return this.inferExpression(node.expression); + case SyntaxKind.MatchExpression: + { + let exprType; + if (node.expression !== null) { + exprType = this.inferExpression(node.expression); + } else { + exprType = this.createTypeVar(); + } + let resultType: Type = this.createTypeVar(); + for (const arm of node.arms) { + const context = this.getContext(); + const newEnv = new TypeEnv(context.env); + const newContext: InferContext = { + constraints: context.constraints, + typeVars: context.typeVars, + env: newEnv, + returnType: context.returnType, + }; + this.pushContext(newContext); + this.addConstraint( + new CEqual( + this.inferBindings(arm.pattern, [], []), + exprType, + arm.pattern, + ) + ); + this.addConstraint( + new CEqual( + resultType, + this.inferExpression(arm.expression), + arm.expression + ) + ); + this.popContext(newContext); + } + if (node === null) { + resultType = new TArrow([ exprType ], resultType); + } + return resultType; + } + case SyntaxKind.ReferenceExpression: { assert(node.modulePath.length === 0); @@ -1430,6 +1473,42 @@ export class Checker { return type; } + case SyntaxKind.LiteralPattern: + { + let type; + switch (pattern.token.kind) { + case SyntaxKind.Integer: + type = this.getIntType(); + break; + case SyntaxKind.StringLiteral: + type = this.getStringType(); + break; + } + type = type.shallowClone(); + type.node = pattern; + return type; + } + + case SyntaxKind.DisjunctivePattern: + { + const type = this.createTypeVar(); + this.addConstraint( + new CEqual( + this.inferBindings(pattern.left, typeVars, constraints), + type, + pattern.left + ) + ); + this.addConstraint( + new CEqual( + this.inferBindings(pattern.right, typeVars, constraints), + type, + pattern.left + ) + ); + return type; + } + case SyntaxKind.StructPattern: { const scheme = this.lookup(pattern.name.text, Symkind.Type); diff --git a/src/cst.ts b/src/cst.ts index 8e0faccb2..656c698df 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -79,6 +79,8 @@ export const enum SyntaxKind { LBracket, RBracket, RArrow, + RArrowAlt, + VBar, Dot, DotDot, Comma, @@ -117,6 +119,8 @@ export const enum SyntaxKind { StructPattern, NestedPattern, NamedTuplePattern, + LiteralPattern, + DisjunctivePattern, // Struct expression elements StructExpressionField, @@ -128,6 +132,7 @@ export const enum SyntaxKind { VariadicStructPatternElement, // Expressions + MatchExpression, MemberExpression, CallExpression, ReferenceExpression, @@ -168,6 +173,7 @@ export const enum SyntaxKind { // Other nodes WrappedOperator, + MatchArm, Initializer, QualifiedName, TypeAssert, @@ -917,8 +923,30 @@ export class RArrow extends TokenBase { } +export class RArrowAlt extends TokenBase { + + public readonly kind = SyntaxKind.RArrowAlt; + + public get text(): string { + return '=>'; + } + +} + +export class VBar extends TokenBase { + + public readonly kind = SyntaxKind.VBar; + + public get text(): string { + return '|'; + } + +} + export type Token = RArrow + | RArrowAlt + | VBar | LParen | RParen | LBrace @@ -1264,12 +1292,57 @@ export class NestedPattern extends SyntaxBase { } +export class DisjunctivePattern extends SyntaxBase { + + public readonly kind = SyntaxKind.DisjunctivePattern; + + public constructor( + public left: Pattern, + public operator: VBar, + public right: Pattern, + ) { + super(); + } + + public getFirstToken(): Token { + return this.left.getFirstToken(); + } + + public getLastToken(): Token { + return this.right.getLastToken(); + } + +} + + +export class LiteralPattern extends SyntaxBase { + + public readonly kind = SyntaxKind.LiteralPattern; + + public constructor( + public token: StringLiteral | Integer + ) { + super(); + } + + public getFirstToken(): Token { + return this.token; + } + + public getLastToken(): Token { + return this.token; + } + +} + export type Pattern = BindPattern | NestedPattern | StructPattern | NamedTuplePattern | TuplePattern + | DisjunctivePattern + | LiteralPattern export class TupleExpression extends SyntaxBase { @@ -1475,6 +1548,56 @@ export class NamedTupleExpression extends SyntaxBase { } +export class MatchArm extends SyntaxBase { + + public readonly kind = SyntaxKind.MatchArm; + + public constructor( + public pattern: Pattern, + public rarrowAlt: RArrowAlt, + public expression: Expression, + ) { + super(); + } + + public getFirstToken(): Token { + return this.pattern.getFirstToken(); + } + + public getLastToken(): Token { + return this.expression.getLastToken(); + } + +} + +export class MatchExpression extends SyntaxBase { + + public readonly kind = SyntaxKind.MatchExpression; + + public constructor( + public matchKeyword: MatchKeyword, + public expression: Expression | null, + public arms: MatchArm[], + ) { + super(); + } + + public getFirstToken(): Token { + return this.matchKeyword; + } + + public getLastToken(): Token { + if (this.arms.length > 0) { + return this.arms[this.arms.length-1].getLastToken(); + } + if (this.expression !== null) { + return this.expression.getLastToken(); + } + return this.matchKeyword; + } + +} + export class ReferenceExpression extends SyntaxBase { public readonly kind = SyntaxKind.ReferenceExpression; @@ -1592,6 +1715,7 @@ export type Expression | ReferenceExpression | ConstantExpression | TupleExpression + | MatchExpression | NestedExpression | PrefixExpression | InfixExpression diff --git a/src/diagnostics.ts b/src/diagnostics.ts index 232b11ec4..7f80642c2 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -87,6 +87,8 @@ const DESCRIPTIONS: Partial> = { [SyntaxKind.BlockEnd]: 'the end of an indented block', [SyntaxKind.LineFoldEnd]: 'the end of the current line-fold', [SyntaxKind.EndOfFile]: 'end-of-file', + [SyntaxKind.RArrowAlt]: '"=>"', + [SyntaxKind.VBar]: "'|'", } function describeSyntaxKind(kind: SyntaxKind): string { diff --git a/src/parser.ts b/src/parser.ts index 8a5640e11..91da7501e 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -25,7 +25,6 @@ import { TypeAssert, ExprBody, BlockBody, - QualifiedName, NestedExpression, NamedTuplePattern, StructPattern, @@ -36,7 +35,6 @@ import { InfixExpression, TextFile, CallExpression, - NamedTupleExpression, LetBodyElement, ReturnStatement, StructExpression, @@ -56,6 +54,10 @@ import { AppTypeExpression, NestedPattern, NestedTypeExpression, + MatchArm, + MatchExpression, + LiteralPattern, + DisjunctivePattern, } from "./cst" import { Stream } from "./util"; @@ -72,11 +74,13 @@ export class ParseError extends Error { } function isBinaryOperatorLike(token: Token): boolean { - return token.kind === SyntaxKind.CustomOperator; + return token.kind === SyntaxKind.CustomOperator + || token.kind === SyntaxKind.VBar; } function isPrefixOperatorLike(token: Token): boolean { - return token.kind === SyntaxKind.CustomOperator; + return token.kind === SyntaxKind.CustomOperator + || token.kind === SyntaxKind.VBar; } const enum OperatorMode { @@ -295,6 +299,33 @@ export class Parser { case SyntaxKind.Identifier: case SyntaxKind.IdentifierAlt: return this.parseReferenceExpression(); + case SyntaxKind.Integer: + case SyntaxKind.StringLiteral: + return this.parseConstantExpression(); + case SyntaxKind.MatchKeyword: + { + this.getToken(); + let expression = null + const t1 = this.peekToken(); + if (t1.kind !== SyntaxKind.BlockStart) { + expression = this.parseExpression(); + } + this.expectToken(SyntaxKind.BlockStart); + const arms = []; + for (;;) { + const t2 = this.peekToken(); + if (t2.kind === SyntaxKind.BlockEnd) { + this.getToken(); + break; + } + const pattern = this.parsePattern(); + const rarrowAlt = this.expectToken(SyntaxKind.RArrowAlt); + const expression = this.parseExpression(); + arms.push(new MatchArm(pattern, rarrowAlt, expression)); + this.expectToken(SyntaxKind.LineFoldEnd); + } + return new MatchExpression(t0, expression, arms); + } case SyntaxKind.LBrace: { this.getToken(); @@ -335,9 +366,6 @@ export class Parser { } return new StructExpression(t0, fields, rbrace); } - case SyntaxKind.Integer: - case SyntaxKind.StringLiteral: - return this.parseConstantExpression(); default: this.raiseParseError(t0, [ SyntaxKind.NamedTupleExpression, @@ -659,7 +687,7 @@ export class Parser { return new TuplePattern(lparen, elements, rparen); } - public parsePattern(): Pattern { + public parsePrimitivePattern(): Pattern { const t0 = this.peekToken(); switch (t0.kind) { case SyntaxKind.LParen: @@ -684,11 +712,31 @@ export class Parser { this.getToken(); return new BindPattern(t0); } + case SyntaxKind.StringLiteral: + case SyntaxKind.Integer: + { + this.getToken(); + return new LiteralPattern(t0); + } default: this.raiseParseError(t0, [ SyntaxKind.Identifier ]); } } + public parsePattern(): Pattern { + let result: Pattern = this.parsePrimitivePattern(); + for (;;) { + const t1 = this.peekToken(); + if (t1.kind !== SyntaxKind.VBar) { + break; + } + this.getToken(); + const right = this.parsePrimitivePattern(); + result = new DisjunctivePattern(result, t1, right); + } + return result; + } + public parseParam(): Param { const pattern = this.parsePattern(); return new Param(pattern); diff --git a/src/scanner.ts b/src/scanner.ts index a0784536c..6f8a962a1 100644 --- a/src/scanner.ts +++ b/src/scanner.ts @@ -37,6 +37,9 @@ import { StructKeyword, RArrow, EnumKeyword, + MatchKeyword, + RArrowAlt, + VBar, } from "./cst" import { Diagnostics, UnexpectedCharDiagnostic } from "./diagnostics" import { Stream, BufferedStream, assert } from "./util"; @@ -250,6 +253,10 @@ export class Scanner extends BufferedStream { const text = c0 + this.takeWhile(isOperatorPart); if (text === '->') { return new RArrow(startPos); + } else if (text === '=>') { + return new RArrowAlt(startPos); + } else if (text === '|') { + return new VBar(startPos); } else if (text === '=') { return new Equals(startPos); } else if (text.endsWith('=') && text[text.length-2] !== '=') { @@ -358,6 +365,7 @@ export class Scanner extends BufferedStream { case 'elif': return new ElifKeyword(startPos); case 'struct': return new StructKeyword(startPos); case 'enum': return new EnumKeyword(startPos); + case 'match': return new MatchKeyword(startPos); default: if (isUpper(text[0])) { return new IdentifierAlt(text, startPos);