From 1ea65236a5ee4aafa548eb77bb907661f19b823d Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Fri, 16 Sep 2022 12:00:00 +0200 Subject: [PATCH] Add support for currying --- src/checker.ts | 77 +++++++++++++++++++--------------------------- src/diagnostics.ts | 33 +------------------- src/parser.ts | 1 - 3 files changed, 33 insertions(+), 78 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index ce55a4750..cd657e9e5 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -1,5 +1,4 @@ import { - Declaration, EnumDeclaration, Expression, LetDeclaration, @@ -13,7 +12,6 @@ import { } from "./cst"; import { describeType, - ArityMismatchDiagnostic, BindingNotFoudDiagnostic, Diagnostics, FieldDoesNotExistDiagnostic, @@ -24,8 +22,6 @@ import { import { assert, isEmpty, MultiMap } from "./util"; import { Analyser } from "./analysis"; -// TODO check that the order by which kindArgs are inserted is correct - const MAX_TYPE_ERROR_COUNT = 5; export enum TypeKind { @@ -107,23 +103,29 @@ export class TArrow extends TypeBase { public readonly kind = TypeKind.Arrow; public constructor( - public paramTypes: Type[], + public paramType: Type, public returnType: Type, public node: Syntax | null = null, ) { super(); } - public *getTypeVars(): Iterable { - for (const paramType of this.paramTypes) { - yield* paramType.getTypeVars(); + public static build(paramTypes: Type[], returnType: Type, node: Syntax | null = null): Type { + let result = returnType; + for (let i = paramTypes.length-1; i >= 0; i--) { + result = new TArrow(paramTypes[i], result, node); } + return result; + } + + public *getTypeVars(): Iterable { + yield* this.paramType.getTypeVars(); yield* this.returnType.getTypeVars(); } public shallowClone(): TArrow { return new TArrow( - this.paramTypes, + this.paramType, this.returnType, this.node, ) @@ -131,19 +133,15 @@ export class TArrow extends TypeBase { public substitute(sub: TVSub): Type { let changed = false; - const newParamTypes = []; - for (const paramType of this.paramTypes) { - const newParamType = paramType.substitute(sub); - if (newParamType !== paramType) { - changed = true; - } - newParamTypes.push(newParamType); + const newParamType = this.paramType.substitute(sub); + if (newParamType !== this.paramType) { + changed = true; } const newReturnType = this.returnType.substitute(sub); if (newReturnType !== this.returnType) { changed = true; } - return changed ? new TArrow(newParamTypes, newReturnType, this.node) : this; + return changed ? new TArrow(newParamType, newReturnType, this.node) : this; } } @@ -1247,7 +1245,7 @@ export class Checker { this.popContext(newContext); } if (node.expression === null) { - resultType = new TArrow([ exprType ], resultType); + resultType = new TArrow(exprType, resultType); } return resultType; } @@ -1298,7 +1296,7 @@ export class Checker { this.addConstraint( new CEqual( opType, - new TArrow(paramTypes, retType), + TArrow.build(paramTypes, retType), node ) ); @@ -1364,7 +1362,7 @@ export class Checker { const rightType = this.inferExpression(node.right); this.addConstraint( new CEqual( - new TArrow([ leftType, rightType ], retType), + new TArrow(leftType, new TArrow(rightType, retType)), opType, node, ), @@ -1445,7 +1443,7 @@ export class Checker { paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars)); } const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars); - type = new TArrow(paramTypes, returnType, node); + type = TArrow.build(paramTypes, returnType, node); break; } @@ -1624,7 +1622,7 @@ export class Checker { case SyntaxKind.EnumDeclarationTupleElement: { const argTypes = member.elements.map(el => this.inferTypeExpression(el)); - elementType = new TArrow(argTypes, TApp.build(type, kindArgs)); + elementType = TArrow.build(argTypes, TApp.build(type, kindArgs), member); break; } case SyntaxKind.EnumDeclarationStructElement: @@ -1633,7 +1631,7 @@ export class Checker { for (const field of member.fields) { fields.set(field.name.text, this.inferTypeExpression(field.typeExpr)); } - elementType = new TArrow([ new TRecord(fields, member) ], TApp.build(type, kindArgs)); + elementType = new TArrow(new TRecord(fields, member), TApp.build(type, kindArgs)); break; } default: @@ -1700,7 +1698,7 @@ export class Checker { this.popContext(context); const type = new TNominal(node); parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type); - parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow([ new TRecord(fields, node) ], TApp.build(type, kindArgs))), Symkind.Var); + parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow(new TRecord(fields, node), TApp.build(type, kindArgs))), Symkind.Var); break; } @@ -1731,18 +1729,18 @@ export class Checker { const b = this.createTypeVar(); const f = this.createTypeVar(); - env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)), Symkind.Var); + env.add('$', new Forall([ f, a ], [], TArrow.build([ a, b, a ], b)), Symkind.Var); env.add('String', new Forall([], [], this.stringType), Symkind.Type); env.add('Int', new Forall([], [], this.intType), Symkind.Type); env.add('Bool', new Forall([], [], this.boolType), Symkind.Type); env.add('True', new Forall([], [], this.boolType), Symkind.Var); env.add('False', new Forall([], [], this.boolType), Symkind.Var); - env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); - env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); - env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); - env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); - env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)), Symkind.Var); - env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)), Symkind.Var); + env.add('+', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('-', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('*', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('/', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('==', new Forall([ a ], [], TArrow.build([ a, a ], this.boolType)), Symkind.Var); + env.add('not', new Forall([], [], new TArrow(this.boolType, this.boolType)), Symkind.Var); this.initialize(node, env); @@ -1793,7 +1791,7 @@ export class Checker { paramTypes.push(paramType); } - let type = new TArrow(paramTypes, returnType); + let type = TArrow.build(paramTypes, returnType, node); if (node.typeAssert !== null) { this.addConstraint( new CEqual( @@ -1958,16 +1956,9 @@ export class Checker { } if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { - if (left.paramTypes.length !== right.paramTypes.length) { - this.diagnostics.add(new ArityMismatchDiagnostic(left, right)); - return false; - } let success = true; - const count = left.paramTypes.length; - for (let i = 0; i < count; i++) { - if (!unify(left.paramTypes[i], right.paramTypes[i])) { - success = false; - } + if (!unify(left.paramType, right.paramType)) { + success = false; } if (!unify(left.returnType, right.returnType)) { success = false; @@ -1978,10 +1969,6 @@ export class Checker { return success; } - if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) { - return unify(left.returnType, right); - } - if (right.kind === TypeKind.Arrow) { return unify(right, left); } diff --git a/src/diagnostics.ts b/src/diagnostics.ts index f8dfc7da2..303e1436b 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -179,15 +179,7 @@ export function describeType(type: Type): string { return 'a' + type.id; case TypeKind.Arrow: { - let out = '('; - let first = true; - for (const paramType of type.paramTypes) { - if (first) first = false; - else out += ', '; - out += describeType(paramType); - } - out += ') -> ' + describeType(type.returnType); - return out; + return describeType(type.paramType) + ' -> ' + describeType(type.returnType); } case TypeKind.Tuple: { @@ -289,28 +281,6 @@ export class UnificationFailedDiagnostic { } -export class ArityMismatchDiagnostic { - - public readonly level = Level.Error; - - public constructor( - public left: TArrow, - public right: TArrow, - ) { - - } - - public format(out: IndentWriter): void { - out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET); - out.write(ANSI_FG_GREEN + describeType(this.left) + ANSI_RESET); - out.write(` has ${this.left.paramTypes.length} `); - out.write(this.left.paramTypes.length === 1 ? 'parameter' : 'parameters'); - out.write(' while ' + ANSI_FG_GREEN + describeType(this.right) + ANSI_RESET); - out.write(` has ${this.right.paramTypes.length}.\n\n`); - } - -} - export class FieldMissingDiagnostic { public readonly level = Level.Error; @@ -384,7 +354,6 @@ export type Diagnostic | BindingNotFoudDiagnostic | UnificationFailedDiagnostic | UnexpectedTokenDiagnostic - | ArityMismatchDiagnostic | FieldMissingDiagnostic | FieldDoesNotExistDiagnostic | KindMismatchDiagnostic diff --git a/src/parser.ts b/src/parser.ts index 91da7501e..9b822611e 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -368,7 +368,6 @@ export class Parser { } default: this.raiseParseError(t0, [ - SyntaxKind.NamedTupleExpression, SyntaxKind.TupleExpression, SyntaxKind.NestedExpression, SyntaxKind.ConstantExpression,