diff --git a/src/checker.ts b/src/checker.ts index 01603f678..b3cf5108f 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -4,13 +4,21 @@ import { Pattern, Scope, SourceFile, - SourceFileElement, - StructDeclaration, + Symkind, Syntax, SyntaxKind, TypeExpression } from "./cst"; -import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, FieldDoesNotExistDiagnostic, FieldMissingDiagnostic, UnificationFailedDiagnostic } from "./diagnostics"; +import { + describeType, + ArityMismatchDiagnostic, + BindingNotFoudDiagnostic, + Diagnostics, + FieldDoesNotExistDiagnostic, + FieldMissingDiagnostic, + UnificationFailedDiagnostic, + KindMismatchDiagnostic +} from "./diagnostics"; import { assert, isEmpty } from "./util"; import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl" @@ -30,6 +38,7 @@ export enum TypeKind { Tuple, Labeled, Record, + App, } abstract class TypeBase { @@ -272,7 +281,7 @@ export class TRecord extends TypeBase { public nextRecord: TRecord | null = null; public constructor( - public decl: StructDeclaration, + public decl: Syntax, public fields: Map, public node: Syntax | null = null, ) { @@ -308,6 +317,68 @@ export class TRecord extends TypeBase { } +export class TApp extends TypeBase { + + public readonly kind = TypeKind.App; + + public constructor( + public operatorType: Type, + public argType: Type, + public node: Syntax | null = null + ) { + super(node); + } + + public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): TApp { + let count = argTypes.length; + let result = argTypes[count-1]; + for (let i = count-2; i >= 0; i--) { + result = new TApp(argTypes[i], result, node); + } + return new TApp(operatorType, result, node); + } + + public *getSequence(): Iterable { + if (this.operatorType.kind === TypeKind.App) { + yield* this.operatorType.getSequence(); + } else { + yield this.operatorType; + } + if (this.argType.kind === TypeKind.App) { + yield* this.argType.getSequence(); + } else { + yield this.argType; + } + } + + public *getTypeVars(): Iterable { + yield* this.operatorType.getTypeVars(); + yield* this.argType.getTypeVars(); + } + + public shallowClone() { + return new TApp( + this.operatorType, + this.argType, + this.node + ); + } + + public substitute(sub: TVSub): Type { + let changed = false; + const newOperatorType = this.operatorType.substitute(sub); + if (newOperatorType !== this.operatorType) { + changed = true; + } + const newArgType = this.argType.substitute(sub); + if (newArgType !== this.argType) { + changed = true; + } + return changed ? new TApp(newOperatorType, newArgType, this.node) : this; + } + +} + export type Type = TCon | TArrow @@ -315,6 +386,7 @@ export type Type | TTuple | TLabeled | TRecord + | TApp class TVSet { @@ -597,11 +669,15 @@ export class Checker { return context.returnType; } - private instantiate(scheme: Scheme, node: Syntax | null): Type { + private createSubstitution(scheme: Scheme): TVSub { const sub = new TVSub(); for (const tv of scheme.typeVars) { sub.set(tv, this.createTypeVar()); } + return sub; + } + + private instantiate(scheme: Scheme, node: Syntax | null, sub = this.createSubstitution(scheme)): Type { for (const constraint of scheme.constraints) { const substituted = constraint.substitute(sub); substituted.node = node; @@ -716,6 +792,8 @@ export class Checker { break; } + case SyntaxKind.TypeDeclaration: + case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: break; @@ -820,13 +898,20 @@ export class Checker { case SyntaxKind.StructExpression: { - const scheme = this.lookup(node.name.text); - if (scheme === null) { + const scope = node.getScope(); + const decl = scope.lookup(node.name.text, Symkind.Constructor); + if (decl === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); return this.createTypeVar(); } - const recordType = this.instantiate(scheme, node); - assert(recordType.kind === TypeKind.Record); + assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement); + const scheme = decl.scheme; + const sub = this.createSubstitution(scheme); + const declType = this.instantiate(scheme, node, sub); + const argTypes = []; + for (const typeVar of decl.tvs) { + argTypes.push(sub.get(typeVar)!); + } const fields = new Map(); for (const member of node.members) { switch (member.kind) { @@ -852,10 +937,10 @@ export class Checker { throw new Error(`Unexpected ${member}`); } } - const type = new TRecord(recordType.decl, fields, node); + const type = TApp.build(new TRecord(decl, fields, node), argTypes, node); this.addConstraint( new CEqual( - recordType, + TApp.build(declType, argTypes, node), type, node, ) @@ -923,6 +1008,16 @@ export class Checker { return scheme.type; } + case SyntaxKind.AppTypeExpression: + { + const operatorType = this.inferTypeExpression(node.operator); + const argTypes = []; + for (const argTypeExpr of node.args) { + argTypes.push(this.inferTypeExpression(argTypeExpr)); + } + return TApp.build(operatorType, argTypes); + } + case SyntaxKind.ArrowTypeExpression: { const paramTypes = []; @@ -1145,6 +1240,7 @@ export class Checker { break; } + case SyntaxKind.TypeDeclaration: case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: break; @@ -1186,6 +1282,7 @@ export class Checker { case SyntaxKind.IfStatement: case SyntaxKind.ReturnStatement: case SyntaxKind.ExpressionStatement: + case SyntaxKind.TypeDeclaration: case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: break; @@ -1232,6 +1329,29 @@ export class Checker { break; } + case SyntaxKind.TypeDeclaration: + { + const env = node.typeEnv = new TypeEnv(parentEnv); + const constraints = new ConstraintSet(); + const typeVars = new TVSet(); + const context: InferContext = { + constraints, + typeVars, + env, + returnType: null, + }; + this.pushContext(context); + for (const varExpr of node.typeVars) { + env.add(varExpr.text, new Forall([], [], this.createTypeVar())); + } + const type = this.inferTypeExpression(node.typeExpression); + this.popContext(context); + const scheme = new Forall(typeVars, constraints, type); + parentEnv.add(node.name.text, scheme); + node.scheme = scheme; + break; + } + case SyntaxKind.StructDeclaration: { const env = node.typeEnv = new TypeEnv(parentEnv); @@ -1244,8 +1364,11 @@ export class Checker { returnType: null, }; this.pushContext(context); + const argTypes = []; for (const varExpr of node.typeVars) { - env.add(varExpr.text, new Forall([], [], this.createTypeVar())); + const type = this.createTypeVar(); + env.add(varExpr.text, new Forall([], [], type)); + argTypes.push(type); } const fields = new Map(); if (node.members !== null) { @@ -1254,8 +1377,11 @@ export class Checker { } } this.popContext(context); - const type = new TRecord(node, fields); - parentEnv.add(node.name.text, new Forall(typeVars, constraints, type)); + const type = new TRecord(node, fields, node); + const scheme = new Forall(typeVars, constraints, type); + parentEnv.add(node.name.text, scheme); + node.tvs = argTypes; + node.scheme = scheme; //new Forall(typeVars, constraints, new TApp(type, argTypes)); break; } @@ -1561,6 +1687,23 @@ export class Checker { } } + if (left.kind === TypeKind.App && right.kind === TypeKind.App) { + let leftElements = [...left.getSequence()]; + let rightElements = [...right.getSequence()]; + if (leftElements.length !== rightElements.length) { + this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node)); + return false; + } + const count = leftElements.length; + let success = true; + for (let i = 0; i < count; i++) { + if (!this.unify(leftElements[i], rightElements[i], solution, constraint)) { + success = false; + } + } + return success; + } + if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) { let success = false; // This works like an ordinary union-find algorithm where an additional @@ -1604,12 +1747,12 @@ export class Checker { } remaining.delete(fieldName); } else { - this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName)); + this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node)); success = false; } } for (const fieldName of remaining) { - this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName)); + this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node)); } if (success) { TypeBase.join(left, right); diff --git a/src/cst.ts b/src/cst.ts index 8d9040c1c..0163b3ab6 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -108,6 +108,7 @@ export const enum SyntaxKind { ReferenceTypeExpression, ArrowTypeExpression, VarTypeExpression, + AppTypeExpression, // Patterns BindPattern, @@ -147,14 +148,11 @@ export const enum SyntaxKind { IfStatementCase, // Declarations - VariableDeclaration, - PrefixFuncDecl, - SuffixFuncDecl, LetDeclaration, StructDeclaration, EnumDeclaration, ImportDeclaration, - TypeAliasDeclaration, + TypeDeclaration, // Let declaration body members ExprBody, @@ -185,6 +183,7 @@ export type Syntax | Param | Body | StructDeclarationField + | EnumDeclarationElement | TypeAssert | Declaration | Statement @@ -207,9 +206,16 @@ function isNodeWithScope(node: Syntax): node is NodeWithScope { || node.kind === SyntaxKind.LetDeclaration; } +export const enum Symkind { + Var = 1, + Type = 2, + Constructor = 4, + Any = Var | Type | Constructor +} + export class Scope { - private mapping = new Map(); + private mapping = new Map(); public constructor( public node: NodeWithScope, @@ -228,6 +234,10 @@ export class Scope { return null; } + private add(name: string, node: Syntax, kind: Symkind): void { + this.mapping.set(name, [kind, node]); + } + private scan(node: Syntax): void { switch (node.kind) { case SyntaxKind.SourceFile: @@ -241,9 +251,17 @@ export class Scope { case SyntaxKind.ReturnStatement: case SyntaxKind.IfStatement: break; + case SyntaxKind.TypeDeclaration: + { + this.add(node.name.text, node, Symkind.Type); + break; + } case SyntaxKind.EnumDeclaration: case SyntaxKind.StructDeclaration: + { + this.add(node.name.text, node, Symkind.Constructor); break; + } case SyntaxKind.LetDeclaration: { for (const param of node.params) { @@ -257,7 +275,7 @@ export class Scope { } } else { if (node.pattern.kind === SyntaxKind.WrappedOperator) { - this.mapping.set(node.pattern.operator.text, node); + this.add(node.pattern.operator.text, node, Symkind.Var); } else { this.scanPattern(node.pattern, node); } @@ -273,7 +291,7 @@ export class Scope { switch (node.kind) { case SyntaxKind.BindPattern: { - this.mapping.set(node.name.text, decl); + this.add(node.name.text, decl, Symkind.Var); break; } case SyntaxKind.StructPattern: @@ -287,7 +305,7 @@ export class Scope { } case SyntaxKind.PunnedStructPatternField: { - this.mapping.set(node.name.text, decl); + this.add(node.name.text, decl, Symkind.Var); break; } } @@ -299,12 +317,15 @@ export class Scope { } } - public lookup(name: string): Syntax | null { + public lookup(name: string, expectedKind = Symkind.Any): Syntax | null { let curr: Scope | null = this; do { - const decl = curr.mapping.get(name); - if (decl !== undefined) { - return decl; + const match = curr.mapping.get(name); + if (match !== undefined) { + const [kind, decl] = match; + if (kind & expectedKind) { + return decl; + } } curr = curr.getParent(); } while (curr !== null); @@ -967,6 +988,30 @@ export class ReferenceTypeExpression extends SyntaxBase { } +export class AppTypeExpression extends SyntaxBase { + + public readonly kind = SyntaxKind.AppTypeExpression; + + public constructor( + public operator: TypeExpression, + public args: TypeExpression[], + ) { + super(); + } + + public getFirstToken(): Token { + return this.operator.getFirstToken(); + } + + public getLastToken(): Token { + if (this.args.length > 0) { + return this.args[this.args.length-1].getLastToken(); + } + return this.operator.getLastToken(); + } + +} + export class VarTypeExpression extends SyntaxBase { public readonly kind = SyntaxKind.VarTypeExpression; @@ -991,6 +1036,7 @@ export type TypeExpression = ReferenceTypeExpression | ArrowTypeExpression | VarTypeExpression + | AppTypeExpression export class BindPattern extends SyntaxBase { @@ -1853,6 +1899,34 @@ export class WrappedOperator extends SyntaxBase { } +export class TypeDeclaration extends SyntaxBase { + + public readonly kind = SyntaxKind.TypeDeclaration; + + public constructor( + public pubKeyword: PubKeyword | null, + public typeKeyword: TypeKeyword, + public name: IdentifierAlt, + public typeVars: Identifier[], + public equals: Equals, + public typeExpression: TypeExpression + ) { + super(); + } + + public getFirstToken(): Token { + if (this.pubKeyword !== null) { + return this.pubKeyword; + } + return this.typeKeyword; + } + + public getLastToken(): Token { + return this.typeExpression.getLastToken(); + } + +} + export class LetDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.LetDeclaration; @@ -1923,6 +1997,7 @@ export type Declaration | ImportDeclaration | StructDeclaration | EnumDeclaration + | TypeDeclaration export class Initializer extends SyntaxBase { diff --git a/src/diagnostics.ts b/src/diagnostics.ts index 5ea4c1749..71856d871 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -198,19 +198,25 @@ export function describeType(type: Type): string { } case TypeKind.Record: { - let out = type.decl.name.text + ' { '; - let first = true; - for (const [fieldName, fieldType] of type.fields) { - if (first) first = false; - else out += ', '; - out += fieldName + ': ' + describeType(fieldType); - } - return out + ' }'; + return type.decl.name.text; + // let out = type.decl.name.text + ' { '; + // let first = true; + // for (const [fieldName, fieldType] of type.fields) { + // if (first) first = false; + // else out += ', '; + // out += fieldName + ': ' + describeType(fieldType); + // } + // return out + ' }'; } case TypeKind.Labeled: { + // FIXME may need to include fields that were added during unification return '{ ' + type.name + ': ' + describeType(type.type) + ' }'; } + case TypeKind.App: + { + return describeType(type.operatorType) + ' ' + describeType(type.argType); + } } } @@ -294,6 +300,7 @@ export class FieldMissingDiagnostic { public constructor( public recordType: TRecord, public fieldName: string, + public node: Syntax | null, ) { } @@ -302,8 +309,8 @@ export class FieldMissingDiagnostic { out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET); out.write(`field '${this.fieldName}' is missing from `); out.write(describeType(this.recordType) + '\n\n'); - if (this.recordType.node !== null) { - out.write(printNode(this.recordType.node) + '\n'); + if (this.node !== null) { + out.write(printNode(this.node) + '\n'); } } @@ -316,13 +323,48 @@ export class FieldDoesNotExistDiagnostic { public constructor( public recordType: TRecord, public fieldName: string, + public node: Syntax | null, ) { + } public format(out: IndentWriter): void { out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET); out.write(`field '${this.fieldName}' does not exist on type `); out.write(describeType(this.recordType) + '\n\n'); + if (this.node !== null) { + out.write(printNode(this.node) + '\n'); + } + } + +} + +export class KindMismatchDiagnostic { + + public readonly level = Level.Error; + + public constructor( + public leftSize: number, + public rightSize: number, + public node: Syntax | null, + ) { + + } + + public format(out: IndentWriter): void { + out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET); + out.write(`kind `); + for (let i = 0; i < this.leftSize-1; i++) { + out.write(`* -> `); + } + out.write(`* does not match with `); + for (let i = 0; i < this.rightSize-1; i++) { + out.write(`* -> `); + } + out.write(`*\n\n`); + if (this.node !== null) { + out.write(printNode(this.node) + '\n'); + } } } @@ -335,6 +377,7 @@ export type Diagnostic | ArityMismatchDiagnostic | FieldMissingDiagnostic | FieldDoesNotExistDiagnostic + | KindMismatchDiagnostic export interface Diagnostics { add(diagnostic: Diagnostic): void; diff --git a/src/parser.ts b/src/parser.ts index 67703e6af..27b269a43 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -52,6 +52,8 @@ import { EnumDeclaration, EnumDeclarationTupleElement, VarTypeExpression, + TypeDeclaration, + AppTypeExpression, } from "./cst" import { Stream } from "./util"; @@ -181,8 +183,30 @@ export class Parser { } } + private tryParseAppTypeExpression(): TypeExpression { + const operator = this.parsePrimitiveTypeExpression(); + const args = []; + for (;;) { + const t1 = this.peekToken(); + if (t1.kind === SyntaxKind.RParen + || t1.kind === SyntaxKind.RBrace + || t1.kind === SyntaxKind.RBracket + || t1.kind === SyntaxKind.Equals + || t1.kind === SyntaxKind.BlockStart + || t1.kind === SyntaxKind.LineFoldEnd + || t1.kind === SyntaxKind.RArrow) { + break; + } + args.push(this.parsePrimitiveTypeExpression()); + } + if (args.length === 0) { + return operator; + } + return new AppTypeExpression(operator, args); + } + public parseTypeExpression(): TypeExpression { - let returnType = this.parsePrimitiveTypeExpression(); + let returnType = this.tryParseAppTypeExpression(); const paramTypes = []; for (;;) { const t1 = this.peekToken(); @@ -191,7 +215,7 @@ export class Parser { } this.getToken(); paramTypes.push(returnType); - returnType = this.parsePrimitiveTypeExpression(); + returnType = this.tryParseAppTypeExpression(); } if (paramTypes.length === 0) { return returnType; @@ -417,6 +441,31 @@ export class Parser { return this.parseBinaryOperatorAfterExpr(lhs, 0); } + public parseTypeDeclaration(): TypeDeclaration { + let pubKeyword = null; + let t0 = this.getToken(); + if (t0.kind === SyntaxKind.PubKeyword) { + pubKeyword = t0; + t0 = this.getToken(); + } + if (t0.kind !== SyntaxKind.TypeKeyword) { + this.raiseParseError(t0, [ SyntaxKind.TypeKeyword ]); + } + const name = this.expectToken(SyntaxKind.IdentifierAlt); + const typeVars = []; + let t1 = this.getToken(); + while (t1.kind === SyntaxKind.Identifier) { + typeVars.push(t1); + t1 = this.getToken(); + } + if (t1.kind !== SyntaxKind.Equals) { + this.raiseParseError(t1, [ SyntaxKind.Equals ]); + } + const typeExpr = this.parseTypeExpression(); + this.expectToken(SyntaxKind.LineFoldEnd); + return new TypeDeclaration(pubKeyword, t0, name, typeVars, t1, typeExpr); + } + public parseEnumDeclaration(): EnumDeclaration { let pubKeyword = null; let t0 = this.getToken(); @@ -817,7 +866,9 @@ export class Parser { case SyntaxKind.StructKeyword: return this.parseStructDeclaration(); case SyntaxKind.EnumKeyword: - return this.parseEnumDeclaration(); + return this.parseStructDeclaration(); + case SyntaxKind.TypeKeyword: + return this.parseTypeDeclaration(); case SyntaxKind.IfKeyword: return this.parseIfStatement(); default: