Many fixes and add better support for enum-declarations
This commit is contained in:
parent
4cc2b23109
commit
c5fe5004b6
3 changed files with 155 additions and 171 deletions
310
src/checker.ts
310
src/checker.ts
|
@ -1,4 +1,6 @@
|
|||
import {
|
||||
EnumDeclaration,
|
||||
EnumDeclarationElement,
|
||||
Expression,
|
||||
LetDeclaration,
|
||||
Pattern,
|
||||
|
@ -280,7 +282,7 @@ export class TRecord extends TypeBase {
|
|||
|
||||
public constructor(
|
||||
public decl: Syntax,
|
||||
public typeVars: TVar[],
|
||||
public kindArgs: TVar[],
|
||||
public fields: Map<string, Type>,
|
||||
public node: Syntax | null = null,
|
||||
) {
|
||||
|
@ -296,7 +298,7 @@ export class TRecord extends TypeBase {
|
|||
public shallowClone(): TRecord {
|
||||
return new TRecord(
|
||||
this.decl,
|
||||
this.typeVars,
|
||||
this.kindArgs,
|
||||
this.fields,
|
||||
this.node
|
||||
);
|
||||
|
@ -305,7 +307,7 @@ export class TRecord extends TypeBase {
|
|||
public substitute(sub: TVSub): Type {
|
||||
let changed = false;
|
||||
const newTypeVars = [];
|
||||
for (const typeVar of this.typeVars) {
|
||||
for (const typeVar of this.kindArgs) {
|
||||
const newTypeVar = typeVar.substitute(sub);
|
||||
assert(newTypeVar.kind === TypeKind.Var);
|
||||
if (newTypeVar !== typeVar) {
|
||||
|
@ -331,59 +333,42 @@ export class TApp extends TypeBase {
|
|||
public readonly kind = TypeKind.App;
|
||||
|
||||
public constructor(
|
||||
public operatorType: Type,
|
||||
public argType: Type,
|
||||
public left: Type,
|
||||
public right: Type,
|
||||
public node: Syntax | null = null
|
||||
) {
|
||||
super(node);
|
||||
}
|
||||
|
||||
public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): Type {
|
||||
if (argTypes.length === 0) {
|
||||
return operatorType;
|
||||
}
|
||||
let count = argTypes.length;
|
||||
let result = argTypes[count-1];
|
||||
for (let i = count-2; i >= 0; i--) {
|
||||
result = new TApp(argTypes[i], result, node);
|
||||
}
|
||||
return new TApp(operatorType, result, node);
|
||||
}
|
||||
|
||||
public *getSequence(): Iterable<Type> {
|
||||
if (this.operatorType.kind === TypeKind.App) {
|
||||
yield* this.operatorType.getSequence();
|
||||
} else {
|
||||
yield this.operatorType;
|
||||
}
|
||||
if (this.argType.kind === TypeKind.App) {
|
||||
yield* this.argType.getSequence();
|
||||
} else {
|
||||
yield this.argType;
|
||||
public static build(types: Type[]): Type {
|
||||
let result = types[0];
|
||||
for (let i = 1; i < types.length; i++) {
|
||||
result = new TApp(result, types[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public *getTypeVars(): Iterable<TVar> {
|
||||
yield* this.operatorType.getTypeVars();
|
||||
yield* this.argType.getTypeVars();
|
||||
yield* this.left.getTypeVars();
|
||||
yield* this.right.getTypeVars();
|
||||
}
|
||||
|
||||
public shallowClone() {
|
||||
return new TApp(
|
||||
this.operatorType,
|
||||
this.argType,
|
||||
this.left,
|
||||
this.right,
|
||||
this.node
|
||||
);
|
||||
}
|
||||
|
||||
public substitute(sub: TVSub): Type {
|
||||
let changed = false;
|
||||
const newOperatorType = this.operatorType.substitute(sub);
|
||||
if (newOperatorType !== this.operatorType) {
|
||||
const newOperatorType = this.left.substitute(sub);
|
||||
if (newOperatorType !== this.left) {
|
||||
changed = true;
|
||||
}
|
||||
const newArgType = this.argType.substitute(sub);
|
||||
if (newArgType !== this.argType) {
|
||||
const newArgType = this.right.substitute(sub);
|
||||
if (newArgType !== this.right) {
|
||||
changed = true;
|
||||
}
|
||||
return changed ? new TApp(newOperatorType, newArgType, this.node) : this;
|
||||
|
@ -396,7 +381,8 @@ export class TVariant extends TypeBase {
|
|||
public readonly kind = TypeKind.Variant;
|
||||
|
||||
public constructor(
|
||||
public typeVars: TVar[],
|
||||
public decl: Syntax,
|
||||
public kindArgs: Type[],
|
||||
public elementTypes: Type[],
|
||||
public node: Syntax | null = null,
|
||||
) {
|
||||
|
@ -411,7 +397,8 @@ export class TVariant extends TypeBase {
|
|||
|
||||
public shallowClone(): Type {
|
||||
return new TVariant(
|
||||
this.typeVars,
|
||||
this.decl,
|
||||
this.kindArgs,
|
||||
this.elementTypes,
|
||||
this.node,
|
||||
);
|
||||
|
@ -420,10 +407,10 @@ export class TVariant extends TypeBase {
|
|||
public substitute(sub: TVSub): Type {
|
||||
let changed = false;
|
||||
const newTypeVars = [];
|
||||
for (const typeVar of this.typeVars) {
|
||||
const newTypeVar = typeVar.substitute(sub);
|
||||
for (const kindArg of this.kindArgs) {
|
||||
const newTypeVar = kindArg.substitute(sub);
|
||||
assert(newTypeVar.kind === TypeKind.Var);
|
||||
if (newTypeVar !== typeVar) {
|
||||
if (newTypeVar !== kindArg) {
|
||||
changed = true;
|
||||
}
|
||||
newTypeVars.push(newTypeVar);
|
||||
|
@ -436,7 +423,7 @@ export class TVariant extends TypeBase {
|
|||
}
|
||||
newElementTypes.push(newElementType);
|
||||
}
|
||||
return changed ? new TVariant(newTypeVars, newElementTypes, this.node) : this;
|
||||
return changed ? new TVariant(this.decl, newTypeVars, newElementTypes, this.node) : this;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -904,11 +891,9 @@ export class Checker {
|
|||
}
|
||||
case SyntaxKind.AppTypeExpression:
|
||||
{
|
||||
let operator = this.inferKindFromTypeExpression(node.operator, env);
|
||||
const args = node.args.map(arg => this.inferKindFromTypeExpression(arg, env));
|
||||
let result = operator;
|
||||
for (const arg of args) {
|
||||
result = this.applyKind(result, arg, node);
|
||||
let result = this.inferKindFromTypeExpression(node.operator, env);;
|
||||
for (const arg of node.args) {
|
||||
result = this.applyKind(result, this.inferKindFromTypeExpression(arg, env), node);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -1032,6 +1017,19 @@ export class Checker {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case SyntaxKind.LetDeclaration:
|
||||
{
|
||||
if (node.typeAssert !== null) {
|
||||
this.unifyKind(this.inferKindFromTypeExpression(node.typeAssert.typeExpression, env), new KStar(), node.typeAssert.typeExpression);
|
||||
}
|
||||
if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
|
||||
for (const element of node.body.elements) {
|
||||
// TODO fork `env` to support local type declarations
|
||||
this.inferKind(element, env);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1069,34 +1067,6 @@ export class Checker {
|
|||
if (a.type === KindType.Arrow && b.type === KindType.Arrow) {
|
||||
return this.unifyKind(a.left, b.left, node)
|
||||
|| this.unifyKind(a.right, b.right, node);
|
||||
// let success = true;
|
||||
// const leftStack = [];
|
||||
// const rightStack = [];
|
||||
// let leftCurr: Kind = a;
|
||||
// let rightCurr: Kind = b;
|
||||
// for (;;) {
|
||||
// while (leftCurr.type === KindType.Arrow) {
|
||||
// leftStack.push(leftCurr);
|
||||
// leftCurr = find(leftCurr.left);
|
||||
// }
|
||||
// while (rightCurr.type === KindType.Arrow) {
|
||||
// rightStack.push(rightCurr);
|
||||
// rightCurr = find(rightCurr.left);
|
||||
// }
|
||||
// if (!this.unifyKind(leftCurr, rightCurr, node)) {
|
||||
// success = false;
|
||||
// }
|
||||
// if (leftStack.length === 0 || rightStack.length === 0) {
|
||||
// if (leftStack.length > 0 || rightStack.length > 0) {
|
||||
// this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
|
||||
// success = false;
|
||||
// }
|
||||
// break;
|
||||
// }
|
||||
// rightCurr = find(rightStack.pop()!.right);
|
||||
// leftCurr = find(leftStack.pop()!.right);
|
||||
// }
|
||||
// return success;
|
||||
}
|
||||
|
||||
this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
|
||||
|
@ -1215,6 +1185,27 @@ export class Checker {
|
|||
|
||||
}
|
||||
|
||||
private buildVariantType(decl: EnumDeclarationElement, type: Type): Type {
|
||||
const enumDecl = decl.parent as EnumDeclaration;
|
||||
const kindArgs = [];
|
||||
for (const _ of enumDecl.varExps) {
|
||||
kindArgs.push(this.createTypeVar());
|
||||
}
|
||||
const variantTypes = [];
|
||||
if (enumDecl.members !== null) {
|
||||
for (const member of enumDecl.members) {
|
||||
let variantType;
|
||||
if (member === decl) {
|
||||
variantType = type;
|
||||
} else {
|
||||
variantType = this.createTypeVar();
|
||||
}
|
||||
variantTypes.push(variantType);
|
||||
}
|
||||
}
|
||||
return TApp.build([ ...kindArgs, new TVariant(enumDecl, [], []) ]);
|
||||
}
|
||||
|
||||
public inferExpression(node: Expression): Type {
|
||||
|
||||
switch (node.kind) {
|
||||
|
@ -1293,18 +1284,27 @@ export class Checker {
|
|||
|
||||
case SyntaxKind.NamedTupleExpression:
|
||||
{
|
||||
// TODO Only lookup constructors and skip other bindings
|
||||
const scheme = this.lookup(node.name.text);
|
||||
if (scheme === null) {
|
||||
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
|
||||
return this.createTypeVar();
|
||||
}
|
||||
const type = this.instantiate(scheme, node.name);
|
||||
assert(type.kind === TypeKind.Con);
|
||||
const argTypes = [];
|
||||
for (const element of node.elements) {
|
||||
argTypes.push(this.inferExpression(element));
|
||||
}
|
||||
return new TCon(type.id, argTypes, type.displayName, node);
|
||||
const operatorType = this.instantiate(scheme, node.name);
|
||||
const argTypes = node.elements.map(el => this.inferExpression(el));
|
||||
const retType = this.createTypeVar();
|
||||
this.addConstraint(
|
||||
new CEqual(
|
||||
new TArrow(
|
||||
argTypes,
|
||||
retType,
|
||||
node,
|
||||
),
|
||||
operatorType,
|
||||
node
|
||||
)
|
||||
);
|
||||
return retType;
|
||||
}
|
||||
|
||||
case SyntaxKind.StructExpression:
|
||||
|
@ -1350,19 +1350,9 @@ export class Checker {
|
|||
throw new Error(`Unexpected ${member}`);
|
||||
}
|
||||
}
|
||||
let type = new TRecord(decl, argTypes, fields, node);
|
||||
let type: Type = new TRecord(decl, argTypes, fields, node);
|
||||
if (decl.kind === SyntaxKind.EnumDeclarationStructElement) {
|
||||
const elementTypes = [];
|
||||
for (const element of decl.parent!.elements) {
|
||||
let elementType;
|
||||
if (element === decl) {
|
||||
elementType = type;
|
||||
} else {
|
||||
elementType = this.createTypeVar();
|
||||
}
|
||||
elementTypes.push(elementType);
|
||||
}
|
||||
type = new TVariant(typeVars, elementTypes);
|
||||
type = this.buildVariantType(decl, type);
|
||||
}
|
||||
this.addConstraint(
|
||||
new CEqual(
|
||||
|
@ -1414,10 +1404,14 @@ export class Checker {
|
|||
return this.createTypeVar();
|
||||
}
|
||||
const type = this.instantiate(scheme, node.name);
|
||||
// FIXME it is not guaranteed that `type` is copied, so the original type might get mutated
|
||||
type.node = node;
|
||||
return type;
|
||||
}
|
||||
|
||||
case SyntaxKind.NestedTypeExpression:
|
||||
return this.inferTypeExpression(node.typeExpr, introduceTypeVars);
|
||||
|
||||
case SyntaxKind.VarTypeExpression:
|
||||
{
|
||||
const scheme = this.lookup(node.name.text);
|
||||
|
@ -1436,21 +1430,20 @@ export class Checker {
|
|||
|
||||
case SyntaxKind.AppTypeExpression:
|
||||
{
|
||||
const operatorType = this.inferTypeExpression(node.operator);
|
||||
const argTypes = [];
|
||||
for (const argTypeExpr of node.args) {
|
||||
argTypes.push(this.inferTypeExpression(argTypeExpr));
|
||||
}
|
||||
return TApp.build(operatorType, argTypes);
|
||||
const argTypes = node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars));
|
||||
return TApp.build([
|
||||
...argTypes,
|
||||
this.inferTypeExpression(node.operator, introduceTypeVars),
|
||||
]);
|
||||
}
|
||||
|
||||
case SyntaxKind.ArrowTypeExpression:
|
||||
{
|
||||
const paramTypes = [];
|
||||
for (const paramTypeExpr of node.paramTypeExprs) {
|
||||
paramTypes.push(this.inferTypeExpression(paramTypeExpr));
|
||||
paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars));
|
||||
}
|
||||
const returnType = this.inferTypeExpression(node.returnTypeExpr);
|
||||
const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars);
|
||||
return new TArrow(paramTypes, returnType, node);
|
||||
}
|
||||
|
||||
|
@ -1751,7 +1744,44 @@ export class Checker {
|
|||
|
||||
case SyntaxKind.EnumDeclaration:
|
||||
{
|
||||
// TODO complete this
|
||||
const env = node.typeEnv = new TypeEnv(parentEnv);
|
||||
const constraints = new ConstraintSet();
|
||||
const typeVars = new TVSet();
|
||||
const context: InferContext = {
|
||||
typeVars,
|
||||
env,
|
||||
constraints,
|
||||
returnType: null,
|
||||
}
|
||||
this.pushContext(context);
|
||||
const kindArgs = [];
|
||||
for (const varExpr of node.varExps) {
|
||||
const kindArg = this.createTypeVar();
|
||||
env.add(varExpr.text, new Forall([], [], kindArg));
|
||||
kindArgs.push(kindArg);
|
||||
}
|
||||
let elementTypes: Type[] = [];
|
||||
const type = new TVariant(node, [], [], node);
|
||||
if (node.members !== null) {
|
||||
for (const member of node.members) {
|
||||
let elementType;
|
||||
switch (member.kind) {
|
||||
case SyntaxKind.EnumDeclarationTupleElement:
|
||||
{
|
||||
const argTypes = member.elements.map(el => this.inferTypeExpression(el));
|
||||
elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ]));
|
||||
parentEnv.add(member.name.text, new Forall([], [], elementType));
|
||||
break;
|
||||
}
|
||||
// TODO
|
||||
default:
|
||||
throw new Error(`Unexpected ${member}`);
|
||||
}
|
||||
elementTypes.push(elementType);
|
||||
}
|
||||
}
|
||||
this.popContext(context);
|
||||
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type));
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -2042,48 +2072,15 @@ export class Checker {
|
|||
// constraint.dump();
|
||||
const unify = (left: Type, right: Type): boolean => {
|
||||
|
||||
const resolveType = (type: Type): Type => {
|
||||
const find = (type: Type): Type => {
|
||||
while (type.kind === TypeKind.Var && solution.has(type)) {
|
||||
type = solution.get(type)!;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
const simplifyType = (type: Type): Type => {
|
||||
|
||||
type = resolveType(type);
|
||||
|
||||
if (type.kind === TypeKind.App) {
|
||||
const stack = [];
|
||||
let i = 0;
|
||||
let operatorType: Type = type;
|
||||
do {
|
||||
operatorType = resolveType(operatorType.operatorType);
|
||||
} while (operatorType.kind === TypeKind.App);
|
||||
assert(isKindedType(operatorType));
|
||||
let curr: Type = resolveType(type);
|
||||
for (;;) {
|
||||
while (curr.kind === TypeKind.App) {
|
||||
stack.push(curr);
|
||||
curr = resolveType(curr.operatorType);
|
||||
}
|
||||
if (curr !== operatorType) {
|
||||
assert(i < operatorType.typeVars!.length);
|
||||
unify(operatorType.typeVars![i++], curr);
|
||||
}
|
||||
if (stack.length === 0) {
|
||||
break;
|
||||
}
|
||||
const next = stack.pop()!;
|
||||
curr = resolveType(next.argType);
|
||||
}
|
||||
return operatorType;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
left = simplifyType(left);
|
||||
right = simplifyType(right);
|
||||
left = find(left);
|
||||
right = find(right);
|
||||
|
||||
if (left.kind === TypeKind.Var) {
|
||||
if (right.hasTypeVar(left)) {
|
||||
|
@ -2145,23 +2142,6 @@ export class Checker {
|
|||
}
|
||||
}
|
||||
|
||||
// if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
|
||||
// let leftElements = [...left.getSequence()];
|
||||
// let rightElements = [...right.getSequence()];
|
||||
// if (leftElements.length !== rightElements.length) {
|
||||
// this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node));
|
||||
// return false;
|
||||
// }
|
||||
// const count = leftElements.length;
|
||||
// let success = true;
|
||||
// for (let i = 0; i < count; i++) {
|
||||
// if (!unify(leftElements[i], rightElements[i])) {
|
||||
// success = false;
|
||||
// }
|
||||
// }
|
||||
// return success;
|
||||
// }
|
||||
|
||||
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
|
||||
let success = false;
|
||||
// This works like an ordinary union-find algorithm where an additional
|
||||
|
@ -2191,6 +2171,19 @@ export class Checker {
|
|||
return success;
|
||||
}
|
||||
|
||||
if (left.kind === TypeKind.Variant && right.kind === TypeKind.Variant) {
|
||||
if (left.decl !== right.decl) {
|
||||
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
|
||||
return unify(left.left, right.left)
|
||||
&& unify(left.right, right.right);
|
||||
}
|
||||
|
||||
if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
|
||||
if (left.decl !== right.decl) {
|
||||
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
|
||||
|
@ -2218,13 +2211,6 @@ export class Checker {
|
|||
return success;
|
||||
}
|
||||
|
||||
// while (left.kind === TypeKind.App) {
|
||||
// left = left.operatorType;
|
||||
// }
|
||||
// while (right.kind === TypeKind.App) {
|
||||
// right = right.operatorType;
|
||||
// }
|
||||
|
||||
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
|
||||
let success = true;
|
||||
if (right.fields === undefined) {
|
||||
|
|
|
@ -1758,6 +1758,8 @@ export class EnumDeclaration extends SyntaxBase {
|
|||
|
||||
public readonly kind = SyntaxKind.EnumDeclaration;
|
||||
|
||||
public typeEnv?: TypeEnv;
|
||||
|
||||
public constructor(
|
||||
public pubKeyword: PubKeyword | null,
|
||||
public enumKeyword: EnumKeyword,
|
||||
|
@ -1810,6 +1812,8 @@ export class StructDeclaration extends SyntaxBase {
|
|||
|
||||
public readonly kind = SyntaxKind.StructDeclaration;
|
||||
|
||||
public typeEnv?: TypeEnv;
|
||||
|
||||
public constructor(
|
||||
public pubKeyword: PubKeyword | null,
|
||||
public structKeyword: StructKeyword,
|
||||
|
@ -1936,6 +1940,8 @@ export class TypeDeclaration extends SyntaxBase {
|
|||
|
||||
public readonly kind = SyntaxKind.TypeDeclaration;
|
||||
|
||||
public typeEnv?: TypeEnv;
|
||||
|
||||
public constructor(
|
||||
public pubKeyword: PubKeyword | null,
|
||||
public typeKeyword: TypeKeyword,
|
||||
|
|
|
@ -204,14 +204,6 @@ export function describeType(type: Type): string {
|
|||
case TypeKind.Record:
|
||||
{
|
||||
return type.decl.name.text;
|
||||
// let out = type.decl.name.text + ' { ';
|
||||
// let first = true;
|
||||
// for (const [fieldName, fieldType] of type.fields) {
|
||||
// if (first) first = false;
|
||||
// else out += ', ';
|
||||
// out += fieldName + ': ' + describeType(fieldType);
|
||||
// }
|
||||
// return out + ' }';
|
||||
}
|
||||
case TypeKind.Labeled:
|
||||
{
|
||||
|
@ -220,7 +212,7 @@ export function describeType(type: Type): string {
|
|||
}
|
||||
case TypeKind.App:
|
||||
{
|
||||
return describeType(type.operatorType) + ' ' + describeType(type.argType);
|
||||
return describeType(type.right) + ' ' + describeType(type.left);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue