Many fixes and add better support for enum-declarations

This commit is contained in:
Sam Vervaeck 2022-09-14 22:34:53 +02:00
parent 4cc2b23109
commit c5fe5004b6
3 changed files with 155 additions and 171 deletions

View file

@ -1,4 +1,6 @@
import {
EnumDeclaration,
EnumDeclarationElement,
Expression,
LetDeclaration,
Pattern,
@ -280,7 +282,7 @@ export class TRecord extends TypeBase {
public constructor(
public decl: Syntax,
public typeVars: TVar[],
public kindArgs: TVar[],
public fields: Map<string, Type>,
public node: Syntax | null = null,
) {
@ -296,7 +298,7 @@ export class TRecord extends TypeBase {
public shallowClone(): TRecord {
return new TRecord(
this.decl,
this.typeVars,
this.kindArgs,
this.fields,
this.node
);
@ -305,7 +307,7 @@ export class TRecord extends TypeBase {
public substitute(sub: TVSub): Type {
let changed = false;
const newTypeVars = [];
for (const typeVar of this.typeVars) {
for (const typeVar of this.kindArgs) {
const newTypeVar = typeVar.substitute(sub);
assert(newTypeVar.kind === TypeKind.Var);
if (newTypeVar !== typeVar) {
@ -331,59 +333,42 @@ export class TApp extends TypeBase {
public readonly kind = TypeKind.App;
public constructor(
public operatorType: Type,
public argType: Type,
public left: Type,
public right: Type,
public node: Syntax | null = null
) {
super(node);
}
public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): Type {
if (argTypes.length === 0) {
return operatorType;
}
let count = argTypes.length;
let result = argTypes[count-1];
for (let i = count-2; i >= 0; i--) {
result = new TApp(argTypes[i], result, node);
}
return new TApp(operatorType, result, node);
}
public *getSequence(): Iterable<Type> {
if (this.operatorType.kind === TypeKind.App) {
yield* this.operatorType.getSequence();
} else {
yield this.operatorType;
}
if (this.argType.kind === TypeKind.App) {
yield* this.argType.getSequence();
} else {
yield this.argType;
public static build(types: Type[]): Type {
let result = types[0];
for (let i = 1; i < types.length; i++) {
result = new TApp(result, types[i]);
}
return result;
}
public *getTypeVars(): Iterable<TVar> {
yield* this.operatorType.getTypeVars();
yield* this.argType.getTypeVars();
yield* this.left.getTypeVars();
yield* this.right.getTypeVars();
}
public shallowClone() {
return new TApp(
this.operatorType,
this.argType,
this.left,
this.right,
this.node
);
}
public substitute(sub: TVSub): Type {
let changed = false;
const newOperatorType = this.operatorType.substitute(sub);
if (newOperatorType !== this.operatorType) {
const newOperatorType = this.left.substitute(sub);
if (newOperatorType !== this.left) {
changed = true;
}
const newArgType = this.argType.substitute(sub);
if (newArgType !== this.argType) {
const newArgType = this.right.substitute(sub);
if (newArgType !== this.right) {
changed = true;
}
return changed ? new TApp(newOperatorType, newArgType, this.node) : this;
@ -396,7 +381,8 @@ export class TVariant extends TypeBase {
public readonly kind = TypeKind.Variant;
public constructor(
public typeVars: TVar[],
public decl: Syntax,
public kindArgs: Type[],
public elementTypes: Type[],
public node: Syntax | null = null,
) {
@ -411,7 +397,8 @@ export class TVariant extends TypeBase {
public shallowClone(): Type {
return new TVariant(
this.typeVars,
this.decl,
this.kindArgs,
this.elementTypes,
this.node,
);
@ -420,10 +407,10 @@ export class TVariant extends TypeBase {
public substitute(sub: TVSub): Type {
let changed = false;
const newTypeVars = [];
for (const typeVar of this.typeVars) {
const newTypeVar = typeVar.substitute(sub);
for (const kindArg of this.kindArgs) {
const newTypeVar = kindArg.substitute(sub);
assert(newTypeVar.kind === TypeKind.Var);
if (newTypeVar !== typeVar) {
if (newTypeVar !== kindArg) {
changed = true;
}
newTypeVars.push(newTypeVar);
@ -436,7 +423,7 @@ export class TVariant extends TypeBase {
}
newElementTypes.push(newElementType);
}
return changed ? new TVariant(newTypeVars, newElementTypes, this.node) : this;
return changed ? new TVariant(this.decl, newTypeVars, newElementTypes, this.node) : this;
}
}
@ -904,11 +891,9 @@ export class Checker {
}
case SyntaxKind.AppTypeExpression:
{
let operator = this.inferKindFromTypeExpression(node.operator, env);
const args = node.args.map(arg => this.inferKindFromTypeExpression(arg, env));
let result = operator;
for (const arg of args) {
result = this.applyKind(result, arg, node);
let result = this.inferKindFromTypeExpression(node.operator, env);;
for (const arg of node.args) {
result = this.applyKind(result, this.inferKindFromTypeExpression(arg, env), node);
}
return result;
}
@ -1032,6 +1017,19 @@ export class Checker {
}
break;
}
case SyntaxKind.LetDeclaration:
{
if (node.typeAssert !== null) {
this.unifyKind(this.inferKindFromTypeExpression(node.typeAssert.typeExpression, env), new KStar(), node.typeAssert.typeExpression);
}
if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
for (const element of node.body.elements) {
// TODO fork `env` to support local type declarations
this.inferKind(element, env);
}
}
break;
}
}
}
@ -1069,34 +1067,6 @@ export class Checker {
if (a.type === KindType.Arrow && b.type === KindType.Arrow) {
return this.unifyKind(a.left, b.left, node)
|| this.unifyKind(a.right, b.right, node);
// let success = true;
// const leftStack = [];
// const rightStack = [];
// let leftCurr: Kind = a;
// let rightCurr: Kind = b;
// for (;;) {
// while (leftCurr.type === KindType.Arrow) {
// leftStack.push(leftCurr);
// leftCurr = find(leftCurr.left);
// }
// while (rightCurr.type === KindType.Arrow) {
// rightStack.push(rightCurr);
// rightCurr = find(rightCurr.left);
// }
// if (!this.unifyKind(leftCurr, rightCurr, node)) {
// success = false;
// }
// if (leftStack.length === 0 || rightStack.length === 0) {
// if (leftStack.length > 0 || rightStack.length > 0) {
// this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
// success = false;
// }
// break;
// }
// rightCurr = find(rightStack.pop()!.right);
// leftCurr = find(leftStack.pop()!.right);
// }
// return success;
}
this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node));
@ -1215,6 +1185,27 @@ export class Checker {
}
private buildVariantType(decl: EnumDeclarationElement, type: Type): Type {
const enumDecl = decl.parent as EnumDeclaration;
const kindArgs = [];
for (const _ of enumDecl.varExps) {
kindArgs.push(this.createTypeVar());
}
const variantTypes = [];
if (enumDecl.members !== null) {
for (const member of enumDecl.members) {
let variantType;
if (member === decl) {
variantType = type;
} else {
variantType = this.createTypeVar();
}
variantTypes.push(variantType);
}
}
return TApp.build([ ...kindArgs, new TVariant(enumDecl, [], []) ]);
}
public inferExpression(node: Expression): Type {
switch (node.kind) {
@ -1293,18 +1284,27 @@ export class Checker {
case SyntaxKind.NamedTupleExpression:
{
// TODO Only lookup constructors and skip other bindings
const scheme = this.lookup(node.name.text);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
}
const type = this.instantiate(scheme, node.name);
assert(type.kind === TypeKind.Con);
const argTypes = [];
for (const element of node.elements) {
argTypes.push(this.inferExpression(element));
}
return new TCon(type.id, argTypes, type.displayName, node);
const operatorType = this.instantiate(scheme, node.name);
const argTypes = node.elements.map(el => this.inferExpression(el));
const retType = this.createTypeVar();
this.addConstraint(
new CEqual(
new TArrow(
argTypes,
retType,
node,
),
operatorType,
node
)
);
return retType;
}
case SyntaxKind.StructExpression:
@ -1350,19 +1350,9 @@ export class Checker {
throw new Error(`Unexpected ${member}`);
}
}
let type = new TRecord(decl, argTypes, fields, node);
let type: Type = new TRecord(decl, argTypes, fields, node);
if (decl.kind === SyntaxKind.EnumDeclarationStructElement) {
const elementTypes = [];
for (const element of decl.parent!.elements) {
let elementType;
if (element === decl) {
elementType = type;
} else {
elementType = this.createTypeVar();
}
elementTypes.push(elementType);
}
type = new TVariant(typeVars, elementTypes);
type = this.buildVariantType(decl, type);
}
this.addConstraint(
new CEqual(
@ -1414,10 +1404,14 @@ export class Checker {
return this.createTypeVar();
}
const type = this.instantiate(scheme, node.name);
// FIXME it is not guaranteed that `type` is copied, so the original type might get mutated
type.node = node;
return type;
}
case SyntaxKind.NestedTypeExpression:
return this.inferTypeExpression(node.typeExpr, introduceTypeVars);
case SyntaxKind.VarTypeExpression:
{
const scheme = this.lookup(node.name.text);
@ -1436,21 +1430,20 @@ export class Checker {
case SyntaxKind.AppTypeExpression:
{
const operatorType = this.inferTypeExpression(node.operator);
const argTypes = [];
for (const argTypeExpr of node.args) {
argTypes.push(this.inferTypeExpression(argTypeExpr));
}
return TApp.build(operatorType, argTypes);
const argTypes = node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars));
return TApp.build([
...argTypes,
this.inferTypeExpression(node.operator, introduceTypeVars),
]);
}
case SyntaxKind.ArrowTypeExpression:
{
const paramTypes = [];
for (const paramTypeExpr of node.paramTypeExprs) {
paramTypes.push(this.inferTypeExpression(paramTypeExpr));
paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars));
}
const returnType = this.inferTypeExpression(node.returnTypeExpr);
const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars);
return new TArrow(paramTypes, returnType, node);
}
@ -1751,7 +1744,44 @@ export class Checker {
case SyntaxKind.EnumDeclaration:
{
// TODO complete this
const env = node.typeEnv = new TypeEnv(parentEnv);
const constraints = new ConstraintSet();
const typeVars = new TVSet();
const context: InferContext = {
typeVars,
env,
constraints,
returnType: null,
}
this.pushContext(context);
const kindArgs = [];
for (const varExpr of node.varExps) {
const kindArg = this.createTypeVar();
env.add(varExpr.text, new Forall([], [], kindArg));
kindArgs.push(kindArg);
}
let elementTypes: Type[] = [];
const type = new TVariant(node, [], [], node);
if (node.members !== null) {
for (const member of node.members) {
let elementType;
switch (member.kind) {
case SyntaxKind.EnumDeclarationTupleElement:
{
const argTypes = member.elements.map(el => this.inferTypeExpression(el));
elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ]));
parentEnv.add(member.name.text, new Forall([], [], elementType));
break;
}
// TODO
default:
throw new Error(`Unexpected ${member}`);
}
elementTypes.push(elementType);
}
}
this.popContext(context);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type));
break;
}
@ -2042,48 +2072,15 @@ export class Checker {
// constraint.dump();
const unify = (left: Type, right: Type): boolean => {
const resolveType = (type: Type): Type => {
const find = (type: Type): Type => {
while (type.kind === TypeKind.Var && solution.has(type)) {
type = solution.get(type)!;
}
return type;
}
const simplifyType = (type: Type): Type => {
type = resolveType(type);
if (type.kind === TypeKind.App) {
const stack = [];
let i = 0;
let operatorType: Type = type;
do {
operatorType = resolveType(operatorType.operatorType);
} while (operatorType.kind === TypeKind.App);
assert(isKindedType(operatorType));
let curr: Type = resolveType(type);
for (;;) {
while (curr.kind === TypeKind.App) {
stack.push(curr);
curr = resolveType(curr.operatorType);
}
if (curr !== operatorType) {
assert(i < operatorType.typeVars!.length);
unify(operatorType.typeVars![i++], curr);
}
if (stack.length === 0) {
break;
}
const next = stack.pop()!;
curr = resolveType(next.argType);
}
return operatorType;
}
return type;
}
left = simplifyType(left);
right = simplifyType(right);
left = find(left);
right = find(right);
if (left.kind === TypeKind.Var) {
if (right.hasTypeVar(left)) {
@ -2145,23 +2142,6 @@ export class Checker {
}
}
// if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
// let leftElements = [...left.getSequence()];
// let rightElements = [...right.getSequence()];
// if (leftElements.length !== rightElements.length) {
// this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node));
// return false;
// }
// const count = leftElements.length;
// let success = true;
// for (let i = 0; i < count; i++) {
// if (!unify(leftElements[i], rightElements[i])) {
// success = false;
// }
// }
// return success;
// }
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
let success = false;
// This works like an ordinary union-find algorithm where an additional
@ -2191,6 +2171,19 @@ export class Checker {
return success;
}
if (left.kind === TypeKind.Variant && right.kind === TypeKind.Variant) {
if (left.decl !== right.decl) {
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
return false;
}
return true;
}
if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
return unify(left.left, right.left)
&& unify(left.right, right.right);
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
if (left.decl !== right.decl) {
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
@ -2218,13 +2211,6 @@ export class Checker {
return success;
}
// while (left.kind === TypeKind.App) {
// left = left.operatorType;
// }
// while (right.kind === TypeKind.App) {
// right = right.operatorType;
// }
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {

View file

@ -1758,6 +1758,8 @@ export class EnumDeclaration extends SyntaxBase {
public readonly kind = SyntaxKind.EnumDeclaration;
public typeEnv?: TypeEnv;
public constructor(
public pubKeyword: PubKeyword | null,
public enumKeyword: EnumKeyword,
@ -1810,6 +1812,8 @@ export class StructDeclaration extends SyntaxBase {
public readonly kind = SyntaxKind.StructDeclaration;
public typeEnv?: TypeEnv;
public constructor(
public pubKeyword: PubKeyword | null,
public structKeyword: StructKeyword,
@ -1936,6 +1940,8 @@ export class TypeDeclaration extends SyntaxBase {
public readonly kind = SyntaxKind.TypeDeclaration;
public typeEnv?: TypeEnv;
public constructor(
public pubKeyword: PubKeyword | null,
public typeKeyword: TypeKeyword,

View file

@ -204,14 +204,6 @@ export function describeType(type: Type): string {
case TypeKind.Record:
{
return type.decl.name.text;
// let out = type.decl.name.text + ' { ';
// let first = true;
// for (const [fieldName, fieldType] of type.fields) {
// if (first) first = false;
// else out += ', ';
// out += fieldName + ': ' + describeType(fieldType);
// }
// return out + ' }';
}
case TypeKind.Labeled:
{
@ -220,7 +212,7 @@ export function describeType(type: Type): string {
}
case TypeKind.App:
{
return describeType(type.operatorType) + ' ' + describeType(type.argType);
return describeType(type.right) + ' ' + describeType(type.left);
}
}
}