Add support for currying

This commit is contained in:
Sam Vervaeck 2022-09-16 12:00:00 +02:00
parent a94b05df61
commit 1ea65236a5
3 changed files with 33 additions and 78 deletions

View file

@ -1,5 +1,4 @@
import {
Declaration,
EnumDeclaration,
Expression,
LetDeclaration,
@ -13,7 +12,6 @@ import {
} from "./cst";
import {
describeType,
ArityMismatchDiagnostic,
BindingNotFoudDiagnostic,
Diagnostics,
FieldDoesNotExistDiagnostic,
@ -24,8 +22,6 @@ import {
import { assert, isEmpty, MultiMap } from "./util";
import { Analyser } from "./analysis";
// TODO check that the order by which kindArgs are inserted is correct
const MAX_TYPE_ERROR_COUNT = 5;
export enum TypeKind {
@ -107,23 +103,29 @@ export class TArrow extends TypeBase {
public readonly kind = TypeKind.Arrow;
public constructor(
public paramTypes: Type[],
public paramType: Type,
public returnType: Type,
public node: Syntax | null = null,
) {
super();
}
public *getTypeVars(): Iterable<TVar> {
for (const paramType of this.paramTypes) {
yield* paramType.getTypeVars();
public static build(paramTypes: Type[], returnType: Type, node: Syntax | null = null): Type {
let result = returnType;
for (let i = paramTypes.length-1; i >= 0; i--) {
result = new TArrow(paramTypes[i], result, node);
}
return result;
}
public *getTypeVars(): Iterable<TVar> {
yield* this.paramType.getTypeVars();
yield* this.returnType.getTypeVars();
}
public shallowClone(): TArrow {
return new TArrow(
this.paramTypes,
this.paramType,
this.returnType,
this.node,
)
@ -131,19 +133,15 @@ export class TArrow extends TypeBase {
public substitute(sub: TVSub): Type {
let changed = false;
const newParamTypes = [];
for (const paramType of this.paramTypes) {
const newParamType = paramType.substitute(sub);
if (newParamType !== paramType) {
const newParamType = this.paramType.substitute(sub);
if (newParamType !== this.paramType) {
changed = true;
}
newParamTypes.push(newParamType);
}
const newReturnType = this.returnType.substitute(sub);
if (newReturnType !== this.returnType) {
changed = true;
}
return changed ? new TArrow(newParamTypes, newReturnType, this.node) : this;
return changed ? new TArrow(newParamType, newReturnType, this.node) : this;
}
}
@ -1247,7 +1245,7 @@ export class Checker {
this.popContext(newContext);
}
if (node.expression === null) {
resultType = new TArrow([ exprType ], resultType);
resultType = new TArrow(exprType, resultType);
}
return resultType;
}
@ -1298,7 +1296,7 @@ export class Checker {
this.addConstraint(
new CEqual(
opType,
new TArrow(paramTypes, retType),
TArrow.build(paramTypes, retType),
node
)
);
@ -1364,7 +1362,7 @@ export class Checker {
const rightType = this.inferExpression(node.right);
this.addConstraint(
new CEqual(
new TArrow([ leftType, rightType ], retType),
new TArrow(leftType, new TArrow(rightType, retType)),
opType,
node,
),
@ -1445,7 +1443,7 @@ export class Checker {
paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars));
}
const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars);
type = new TArrow(paramTypes, returnType, node);
type = TArrow.build(paramTypes, returnType, node);
break;
}
@ -1624,7 +1622,7 @@ export class Checker {
case SyntaxKind.EnumDeclarationTupleElement:
{
const argTypes = member.elements.map(el => this.inferTypeExpression(el));
elementType = new TArrow(argTypes, TApp.build(type, kindArgs));
elementType = TArrow.build(argTypes, TApp.build(type, kindArgs), member);
break;
}
case SyntaxKind.EnumDeclarationStructElement:
@ -1633,7 +1631,7 @@ export class Checker {
for (const field of member.fields) {
fields.set(field.name.text, this.inferTypeExpression(field.typeExpr));
}
elementType = new TArrow([ new TRecord(fields, member) ], TApp.build(type, kindArgs));
elementType = new TArrow(new TRecord(fields, member), TApp.build(type, kindArgs));
break;
}
default:
@ -1700,7 +1698,7 @@ export class Checker {
this.popContext(context);
const type = new TNominal(node);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow([ new TRecord(fields, node) ], TApp.build(type, kindArgs))), Symkind.Var);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow(new TRecord(fields, node), TApp.build(type, kindArgs))), Symkind.Var);
break;
}
@ -1731,18 +1729,18 @@ export class Checker {
const b = this.createTypeVar();
const f = this.createTypeVar();
env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)), Symkind.Var);
env.add('$', new Forall([ f, a ], [], TArrow.build([ a, b, a ], b)), Symkind.Var);
env.add('String', new Forall([], [], this.stringType), Symkind.Type);
env.add('Int', new Forall([], [], this.intType), Symkind.Type);
env.add('Bool', new Forall([], [], this.boolType), Symkind.Type);
env.add('True', new Forall([], [], this.boolType), Symkind.Var);
env.add('False', new Forall([], [], this.boolType), Symkind.Var);
env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)), Symkind.Var);
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)), Symkind.Var);
env.add('+', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('-', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('*', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('/', new Forall([], [], TArrow.build([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('==', new Forall([ a ], [], TArrow.build([ a, a ], this.boolType)), Symkind.Var);
env.add('not', new Forall([], [], new TArrow(this.boolType, this.boolType)), Symkind.Var);
this.initialize(node, env);
@ -1793,7 +1791,7 @@ export class Checker {
paramTypes.push(paramType);
}
let type = new TArrow(paramTypes, returnType);
let type = TArrow.build(paramTypes, returnType, node);
if (node.typeAssert !== null) {
this.addConstraint(
new CEqual(
@ -1958,17 +1956,10 @@ export class Checker {
}
if (left.kind === TypeKind.Arrow && right.kind === TypeKind.Arrow) {
if (left.paramTypes.length !== right.paramTypes.length) {
this.diagnostics.add(new ArityMismatchDiagnostic(left, right));
return false;
}
let success = true;
const count = left.paramTypes.length;
for (let i = 0; i < count; i++) {
if (!unify(left.paramTypes[i], right.paramTypes[i])) {
if (!unify(left.paramType, right.paramType)) {
success = false;
}
}
if (!unify(left.returnType, right.returnType)) {
success = false;
}
@ -1978,10 +1969,6 @@ export class Checker {
return success;
}
if (left.kind === TypeKind.Arrow && left.paramTypes.length === 0) {
return unify(left.returnType, right);
}
if (right.kind === TypeKind.Arrow) {
return unify(right, left);
}

View file

@ -179,15 +179,7 @@ export function describeType(type: Type): string {
return 'a' + type.id;
case TypeKind.Arrow:
{
let out = '(';
let first = true;
for (const paramType of type.paramTypes) {
if (first) first = false;
else out += ', ';
out += describeType(paramType);
}
out += ') -> ' + describeType(type.returnType);
return out;
return describeType(type.paramType) + ' -> ' + describeType(type.returnType);
}
case TypeKind.Tuple:
{
@ -289,28 +281,6 @@ export class UnificationFailedDiagnostic {
}
export class ArityMismatchDiagnostic {
public readonly level = Level.Error;
public constructor(
public left: TArrow,
public right: TArrow,
) {
}
public format(out: IndentWriter): void {
out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET);
out.write(ANSI_FG_GREEN + describeType(this.left) + ANSI_RESET);
out.write(` has ${this.left.paramTypes.length} `);
out.write(this.left.paramTypes.length === 1 ? 'parameter' : 'parameters');
out.write(' while ' + ANSI_FG_GREEN + describeType(this.right) + ANSI_RESET);
out.write(` has ${this.right.paramTypes.length}.\n\n`);
}
}
export class FieldMissingDiagnostic {
public readonly level = Level.Error;
@ -384,7 +354,6 @@ export type Diagnostic
| BindingNotFoudDiagnostic
| UnificationFailedDiagnostic
| UnexpectedTokenDiagnostic
| ArityMismatchDiagnostic
| FieldMissingDiagnostic
| FieldDoesNotExistDiagnostic
| KindMismatchDiagnostic

View file

@ -368,7 +368,6 @@ export class Parser {
}
default:
this.raiseParseError(t0, [
SyntaxKind.NamedTupleExpression,
SyntaxKind.TupleExpression,
SyntaxKind.NestedExpression,
SyntaxKind.ConstantExpression,