Add support for currying

This commit is contained in:
Sam Vervaeck 2022-09-16 12:00:00 +02:00
parent a94b05df61
commit 1ea65236a5
3 changed files with 33 additions and 78 deletions

View file

@ -1,5 +1,4 @@
import { import {
Declaration,
EnumDeclaration, EnumDeclaration,
Expression, Expression,
LetDeclaration, LetDeclaration,
@ -13,7 +12,6 @@ import {
} from "./cst"; } from "./cst";
import { import {
describeType, describeType,
ArityMismatchDiagnostic,
BindingNotFoudDiagnostic, BindingNotFoudDiagnostic,
Diagnostics, Diagnostics,
FieldDoesNotExistDiagnostic, FieldDoesNotExistDiagnostic,
@ -24,8 +22,6 @@ import {
import { assert, isEmpty, MultiMap } from "./util"; import { assert, isEmpty, MultiMap } from "./util";
import { Analyser } from "./analysis"; import { Analyser } from "./analysis";
// TODO check that the order by which kindArgs are inserted is correct
const MAX_TYPE_ERROR_COUNT = 5; const MAX_TYPE_ERROR_COUNT = 5;
export enum TypeKind { export enum TypeKind {
@ -107,23 +103,29 @@ export class TArrow extends TypeBase {
public readonly kind = TypeKind.Arrow; public readonly kind = TypeKind.Arrow;
public constructor( public constructor(
public paramTypes: Type[], public paramType: Type,
public returnType: Type, public returnType: Type,
public node: Syntax | null = null, public node: Syntax | null = null,
) { ) {
super(); super();
} }
public *getTypeVars(): Iterable<TVar> { public static build(paramTypes: Type[], returnType: Type, node: Syntax | null = null): Type {
for (const paramType of this.paramTypes) { let result = returnType;
yield* paramType.getTypeVars(); for (let i = paramTypes.length-1; i >= 0; i--) {
result = new TArrow(paramTypes[i], result, node);
} }
return result;
}
public *getTypeVars(): Iterable<TVar> {
yield* this.paramType.getTypeVars();
yield* this.returnType.getTypeVars(); yield* this.returnType.getTypeVars();
} }
public shallowClone(): TArrow { public shallowClone(): TArrow {
return new TArrow( return new TArrow(
this.paramTypes, this.paramType,
this.returnType, this.returnType,
this.node, this.node,
) )
@ -131,19 +133,15 @@ export class TArrow extends TypeBase {
public substitute(sub: TVSub): Type { public substitute(sub: TVSub): Type {
let changed = false; let changed = false;
const newParamTypes = []; const newParamType = this.paramType.substitute(sub);
for (const paramType of this.paramTypes) { if (newParamType !== this.paramType) {
const newParamType = paramType.substitute(sub);
if (newParamType !== paramType) {
changed = true; changed = true;
} }
newParamTypes.push(newParamType);
}
const newReturnType = this.returnType.substitute(sub); const newReturnType = this.returnType.substitute(sub);
if (newReturnType !== this.returnType) { if (newReturnType !== this.returnType) {
changed = true; changed = true;
} }
return changed ? new TArrow(newParamTypes, newReturnType, this.node) : this; return changed ? new TArrow(newParamType, newReturnType, this.node) : this;
} }
} }
@ -1247,7 +1245,7 @@ export class Checker {
this.popContext(newContext); this.popContext(newContext);
} }
if (node.expression === null) { if (node.expression === null) {
resultType = new TArrow([ exprType ], resultType); resultType = new TArrow(exprType, resultType);
} }
return resultType; return resultType;
} }
@ -1298,7 +1296,7 @@ export class Checker {
this.addConstraint( this.addConstraint(
new CEqual( new CEqual(
opType, opType,
new TArrow(paramTypes, retType), TArrow.build(paramTypes, retType),
node node
) )
); );
@ -1364,7 +1362,7 @@ export class Checker {
const rightType = this.inferExpression(node.right); const rightType = this.inferExpression(node.right);
this.addConstraint( this.addConstraint(
new CEqual( new CEqual(
new TArrow([ leftType, rightType ], retType), new TArrow(leftType, new TArrow(rightType, retType)),
opType, opType,
node, node,
), ),
@ -1445,7 +1443,7 @@ export class Checker {
paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars)); paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars));
} }
const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars); const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars);
type = new TArrow(paramTypes, returnType, node); type = TArrow.build(paramTypes, returnType, node);
break; break;
} }
@ -1624,7 +1622,7 @@ export class Checker {
case SyntaxKind.EnumDeclarationTupleElement: case SyntaxKind.EnumDeclarationTupleElement:
{ {
const argTypes = member.elements.map(el => this.inferTypeExpression(el)); const argTypes = member.elements.map(el => this.inferTypeExpression(el));
elementType = new TArrow(argTypes, TApp.build(type, kindArgs)); elementType = TArrow.build(argTypes, TApp.build(type, kindArgs), member);
break; break;
} }
case SyntaxKind.EnumDeclarationStructElement: case SyntaxKind.EnumDeclarationStructElement:
@ -1633,7 +1631,7 @@ export class Checker {
for (const field of member.fields) { for (const field of member.fields) {
fields.set(field.name.text, this.inferTypeExpression(field.typeExpr)); fields.set(field.name.text, this.inferTypeExpression(field.typeExpr));
} }
elementType = new TArrow([ new TRecord(fields, member) ], TApp.build(type, kindArgs)); elementType = new TArrow(new TRecord(fields, member), TApp.build(type, kindArgs));
break; break;
} }
default: default:
@ -1700,7 +1698,7 @@ export class Checker {
this.popContext(context); this.popContext(context);
const type = new TNominal(node); const type = new TNominal(node);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type); parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow([ new TRecord(fields, node) ], TApp.build(type, kindArgs))), Symkind.Var); parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow(new TRecord(fields, node), TApp.build(type, kindArgs))), Symkind.Var);
break; break;
} }
@ -1731,18 +1729,18 @@ export class Checker {
const b = this.createTypeVar(); const b = this.createTypeVar();
const f = this.createTypeVar(); const f = this.createTypeVar();
env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)), Symkind.Var); env.add('$', new Forall([ f, a ], [], TArrow.build([ a, b, a ], b)), Symkind.Var);
env.add('String', new Forall([], [], this.stringType), Symkind.Type); env.add('String', new Forall([], [], this.stringType), Symkind.Type);
env.add('Int', new Forall([], [], this.intType), Symkind.Type); env.add('Int', new Forall([], [], this.intType), Symkind.Type);
env.add('Bool', new Forall([], [], this.boolType), Symkind.Type); env.add('Bool', new Forall([], [], this.boolType), Symkind.Type);
env.add('True', new Forall([], [], this.boolType), Symkind.Var); env.add('True', new Forall([], [], this.boolType), Symkind.Var);
env.add('False', new Forall([], [], this.boolType), Symkind.Var); env.add('False', new Forall([], [], this.boolType), Symkind.Var);
env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); env.add('+', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); env.add('-', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); env.add('*', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); env.add('/', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)), Symkind.Var); env.add('==', new Forall([ a ], [], TArrow.build([ a, a ], this.boolType)), Symkind.Var);
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)), Symkind.Var); env.add('not', new Forall([], [], new TArrow(this.boolType, this.boolType)), Symkind.Var);
this.initialize(node, env); this.initialize(node, env);
@ -1793,7 +1791,7 @@ export class Checker {
paramTypes.push(paramType); paramTypes.push(paramType);
} }
let type = new TArrow(paramTypes, returnType); let type = TArrow.build(paramTypes, returnType, node);
if (node.typeAssert !== null) { if (node.typeAssert !== null) {
this.addConstraint( this.addConstraint(
new CEqual( new CEqual(
@ -1958,17 +1956,10 @@ export class Checker {
} }
if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) { if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) {
if (left.paramTypes.length !== right.paramTypes.length) {
this.diagnostics.add(new ArityMismatchDiagnostic(left, right));
return false;
}
let success = true; let success = true;
const count = left.paramTypes.length; if (!unify(left.paramType, right.paramType)) {
for (let i = 0; i < count; i++) {
if (!unify(left.paramTypes[i], right.paramTypes[i])) {
success = false; success = false;
} }
}
if (!unify(left.returnType, right.returnType)) { if (!unify(left.returnType, right.returnType)) {
success = false; success = false;
} }
@ -1978,10 +1969,6 @@ export class Checker {
return success; return success;
} }
if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) {
return unify(left.returnType, right);
}
if (right.kind === TypeKind.Arrow) { if (right.kind === TypeKind.Arrow) {
return unify(right, left); return unify(right, left);
} }

View file

@ -179,15 +179,7 @@ export function describeType(type: Type): string {
return 'a' + type.id; return 'a' + type.id;
case TypeKind.Arrow: case TypeKind.Arrow:
{ {
let out = '('; return describeType(type.paramType) + ' -> ' + describeType(type.returnType);
let first = true;
for (const paramType of type.paramTypes) {
if (first) first = false;
else out += ', ';
out += describeType(paramType);
}
out += ') -> ' + describeType(type.returnType);
return out;
} }
case TypeKind.Tuple: case TypeKind.Tuple:
{ {
@ -289,28 +281,6 @@ export class UnificationFailedDiagnostic {
} }
export class ArityMismatchDiagnostic {
public readonly level = Level.Error;
public constructor(
public left: TArrow,
public right: TArrow,
) {
}
public format(out: IndentWriter): void {
out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET);
out.write(ANSI_FG_GREEN + describeType(this.left) + ANSI_RESET);
out.write(` has ${this.left.paramTypes.length} `);
out.write(this.left.paramTypes.length === 1 ? 'parameter' : 'parameters');
out.write(' while ' + ANSI_FG_GREEN + describeType(this.right) + ANSI_RESET);
out.write(` has ${this.right.paramTypes.length}.\n\n`);
}
}
export class FieldMissingDiagnostic { export class FieldMissingDiagnostic {
public readonly level = Level.Error; public readonly level = Level.Error;
@ -384,7 +354,6 @@ export type Diagnostic
| BindingNotFoudDiagnostic | BindingNotFoudDiagnostic
| UnificationFailedDiagnostic | UnificationFailedDiagnostic
| UnexpectedTokenDiagnostic | UnexpectedTokenDiagnostic
| ArityMismatchDiagnostic
| FieldMissingDiagnostic | FieldMissingDiagnostic
| FieldDoesNotExistDiagnostic | FieldDoesNotExistDiagnostic
| KindMismatchDiagnostic | KindMismatchDiagnostic

View file

@ -368,7 +368,6 @@ export class Parser {
} }
default: default:
this.raiseParseError(t0, [ this.raiseParseError(t0, [
SyntaxKind.NamedTupleExpression,
SyntaxKind.TupleExpression, SyntaxKind.TupleExpression,
SyntaxKind.NestedExpression, SyntaxKind.NestedExpression,
SyntaxKind.ConstantExpression, SyntaxKind.ConstantExpression,