diff --git a/src/checker.ts b/src/checker.ts index 86f739d26..d806e404c 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -1,4 +1,6 @@ import { + EnumDeclaration, + EnumDeclarationElement, Expression, LetDeclaration, Pattern, @@ -280,7 +282,7 @@ export class TRecord extends TypeBase { public constructor( public decl: Syntax, - public typeVars: TVar[], + public kindArgs: TVar[], public fields: Map, public node: Syntax | null = null, ) { @@ -296,7 +298,7 @@ export class TRecord extends TypeBase { public shallowClone(): TRecord { return new TRecord( this.decl, - this.typeVars, + this.kindArgs, this.fields, this.node ); @@ -305,7 +307,7 @@ export class TRecord extends TypeBase { public substitute(sub: TVSub): Type { let changed = false; const newTypeVars = []; - for (const typeVar of this.typeVars) { + for (const typeVar of this.kindArgs) { const newTypeVar = typeVar.substitute(sub); assert(newTypeVar.kind === TypeKind.Var); if (newTypeVar !== typeVar) { @@ -331,59 +333,42 @@ export class TApp extends TypeBase { public readonly kind = TypeKind.App; public constructor( - public operatorType: Type, - public argType: Type, + public left: Type, + public right: Type, public node: Syntax | null = null ) { super(node); } - public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): Type { - if (argTypes.length === 0) { - return operatorType; - } - let count = argTypes.length; - let result = argTypes[count-1]; - for (let i = count-2; i >= 0; i--) { - result = new TApp(argTypes[i], result, node); - } - return new TApp(operatorType, result, node); - } - - public *getSequence(): Iterable { - if (this.operatorType.kind === TypeKind.App) { - yield* this.operatorType.getSequence(); - } else { - yield this.operatorType; - } - if (this.argType.kind === TypeKind.App) { - yield* this.argType.getSequence(); - } else { - yield this.argType; + public static build(types: Type[]): Type { + let result = types[0]; + for (let i = 1; i < types.length; i++) { + result = new TApp(result, types[i]); } + return result; } public *getTypeVars(): Iterable { - yield* this.operatorType.getTypeVars(); - yield* this.argType.getTypeVars(); + yield* this.left.getTypeVars(); + yield* this.right.getTypeVars(); } public shallowClone() { return new TApp( - this.operatorType, - this.argType, + this.left, + this.right, this.node ); } public substitute(sub: TVSub): Type { let changed = false; - const newOperatorType = this.operatorType.substitute(sub); - if (newOperatorType !== this.operatorType) { + const newOperatorType = this.left.substitute(sub); + if (newOperatorType !== this.left) { changed = true; } - const newArgType = this.argType.substitute(sub); - if (newArgType !== this.argType) { + const newArgType = this.right.substitute(sub); + if (newArgType !== this.right) { changed = true; } return changed ? new TApp(newOperatorType, newArgType, this.node) : this; @@ -396,7 +381,8 @@ export class TVariant extends TypeBase { public readonly kind = TypeKind.Variant; public constructor( - public typeVars: TVar[], + public decl: Syntax, + public kindArgs: Type[], public elementTypes: Type[], public node: Syntax | null = null, ) { @@ -411,7 +397,8 @@ export class TVariant extends TypeBase { public shallowClone(): Type { return new TVariant( - this.typeVars, + this.decl, + this.kindArgs, this.elementTypes, this.node, ); @@ -420,10 +407,10 @@ export class TVariant extends TypeBase { public substitute(sub: TVSub): Type { let changed = false; const newTypeVars = []; - for (const typeVar of this.typeVars) { - const newTypeVar = typeVar.substitute(sub); + for (const kindArg of this.kindArgs) { + const newTypeVar = kindArg.substitute(sub); assert(newTypeVar.kind === TypeKind.Var); - if (newTypeVar !== typeVar) { + if (newTypeVar !== kindArg) { changed = true; } newTypeVars.push(newTypeVar); @@ -436,7 +423,7 @@ export class TVariant extends TypeBase { } newElementTypes.push(newElementType); } - return changed ? new TVariant(newTypeVars, newElementTypes, this.node) : this; + return changed ? new TVariant(this.decl, newTypeVars, newElementTypes, this.node) : this; } } @@ -904,11 +891,9 @@ export class Checker { } case SyntaxKind.AppTypeExpression: { - let operator = this.inferKindFromTypeExpression(node.operator, env); - const args = node.args.map(arg => this.inferKindFromTypeExpression(arg, env)); - let result = operator; - for (const arg of args) { - result = this.applyKind(result, arg, node); + let result = this.inferKindFromTypeExpression(node.operator, env);; + for (const arg of node.args) { + result = this.applyKind(result, this.inferKindFromTypeExpression(arg, env), node); } return result; } @@ -1032,6 +1017,19 @@ export class Checker { } break; } + case SyntaxKind.LetDeclaration: + { + if (node.typeAssert !== null) { + this.unifyKind(this.inferKindFromTypeExpression(node.typeAssert.typeExpression, env), new KStar(), node.typeAssert.typeExpression); + } + if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) { + for (const element of node.body.elements) { + // TODO fork `env` to support local type declarations + this.inferKind(element, env); + } + } + break; + } } } @@ -1069,34 +1067,6 @@ export class Checker { if (a.type === KindType.Arrow && b.type === KindType.Arrow) { return this.unifyKind(a.left, b.left, node) || this.unifyKind(a.right, b.right, node); - // let success = true; - // const leftStack = []; - // const rightStack = []; - // let leftCurr: Kind = a; - // let rightCurr: Kind = b; - // for (;;) { - // while (leftCurr.type === KindType.Arrow) { - // leftStack.push(leftCurr); - // leftCurr = find(leftCurr.left); - // } - // while (rightCurr.type === KindType.Arrow) { - // rightStack.push(rightCurr); - // rightCurr = find(rightCurr.left); - // } - // if (!this.unifyKind(leftCurr, rightCurr, node)) { - // success = false; - // } - // if (leftStack.length === 0 || rightStack.length === 0) { - // if (leftStack.length > 0 || rightStack.length > 0) { - // this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node)); - // success = false; - // } - // break; - // } - // rightCurr = find(rightStack.pop()!.right); - // leftCurr = find(leftStack.pop()!.right); - // } - // return success; } this.diagnostics.add(new KindMismatchDiagnostic(solve(a), solve(b), node)); @@ -1215,6 +1185,27 @@ export class Checker { } + private buildVariantType(decl: EnumDeclarationElement, type: Type): Type { + const enumDecl = decl.parent as EnumDeclaration; + const kindArgs = []; + for (const _ of enumDecl.varExps) { + kindArgs.push(this.createTypeVar()); + } + const variantTypes = []; + if (enumDecl.members !== null) { + for (const member of enumDecl.members) { + let variantType; + if (member === decl) { + variantType = type; + } else { + variantType = this.createTypeVar(); + } + variantTypes.push(variantType); + } + } + return TApp.build([ ...kindArgs, new TVariant(enumDecl, [], []) ]); + } + public inferExpression(node: Expression): Type { switch (node.kind) { @@ -1293,18 +1284,27 @@ export class Checker { case SyntaxKind.NamedTupleExpression: { + // TODO Only lookup constructors and skip other bindings const scheme = this.lookup(node.name.text); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); return this.createTypeVar(); } - const type = this.instantiate(scheme, node.name); - assert(type.kind === TypeKind.Con); - const argTypes = []; - for (const element of node.elements) { - argTypes.push(this.inferExpression(element)); - } - return new TCon(type.id, argTypes, type.displayName, node); + const operatorType = this.instantiate(scheme, node.name); + const argTypes = node.elements.map(el => this.inferExpression(el)); + const retType = this.createTypeVar(); + this.addConstraint( + new CEqual( + new TArrow( + argTypes, + retType, + node, + ), + operatorType, + node + ) + ); + return retType; } case SyntaxKind.StructExpression: @@ -1350,19 +1350,9 @@ export class Checker { throw new Error(`Unexpected ${member}`); } } - let type = new TRecord(decl, argTypes, fields, node); + let type: Type = new TRecord(decl, argTypes, fields, node); if (decl.kind === SyntaxKind.EnumDeclarationStructElement) { - const elementTypes = []; - for (const element of decl.parent!.elements) { - let elementType; - if (element === decl) { - elementType = type; - } else { - elementType = this.createTypeVar(); - } - elementTypes.push(elementType); - } - type = new TVariant(typeVars, elementTypes); + type = this.buildVariantType(decl, type); } this.addConstraint( new CEqual( @@ -1414,10 +1404,14 @@ export class Checker { return this.createTypeVar(); } const type = this.instantiate(scheme, node.name); + // FIXME it is not guaranteed that `type` is copied, so the original type might get mutated type.node = node; return type; } + case SyntaxKind.NestedTypeExpression: + return this.inferTypeExpression(node.typeExpr, introduceTypeVars); + case SyntaxKind.VarTypeExpression: { const scheme = this.lookup(node.name.text); @@ -1436,21 +1430,20 @@ export class Checker { case SyntaxKind.AppTypeExpression: { - const operatorType = this.inferTypeExpression(node.operator); - const argTypes = []; - for (const argTypeExpr of node.args) { - argTypes.push(this.inferTypeExpression(argTypeExpr)); - } - return TApp.build(operatorType, argTypes); + const argTypes = node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars)); + return TApp.build([ + ...argTypes, + this.inferTypeExpression(node.operator, introduceTypeVars), + ]); } case SyntaxKind.ArrowTypeExpression: { const paramTypes = []; for (const paramTypeExpr of node.paramTypeExprs) { - paramTypes.push(this.inferTypeExpression(paramTypeExpr)); + paramTypes.push(this.inferTypeExpression(paramTypeExpr, introduceTypeVars)); } - const returnType = this.inferTypeExpression(node.returnTypeExpr); + const returnType = this.inferTypeExpression(node.returnTypeExpr, introduceTypeVars); return new TArrow(paramTypes, returnType, node); } @@ -1751,7 +1744,44 @@ export class Checker { case SyntaxKind.EnumDeclaration: { - // TODO complete this + const env = node.typeEnv = new TypeEnv(parentEnv); + const constraints = new ConstraintSet(); + const typeVars = new TVSet(); + const context: InferContext = { + typeVars, + env, + constraints, + returnType: null, + } + this.pushContext(context); + const kindArgs = []; + for (const varExpr of node.varExps) { + const kindArg = this.createTypeVar(); + env.add(varExpr.text, new Forall([], [], kindArg)); + kindArgs.push(kindArg); + } + let elementTypes: Type[] = []; + const type = new TVariant(node, [], [], node); + if (node.members !== null) { + for (const member of node.members) { + let elementType; + switch (member.kind) { + case SyntaxKind.EnumDeclarationTupleElement: + { + const argTypes = member.elements.map(el => this.inferTypeExpression(el)); + elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ])); + parentEnv.add(member.name.text, new Forall([], [], elementType)); + break; + } + // TODO + default: + throw new Error(`Unexpected ${member}`); + } + elementTypes.push(elementType); + } + } + this.popContext(context); + parentEnv.add(node.name.text, new Forall(typeVars, constraints, type)); break; } @@ -2042,48 +2072,15 @@ export class Checker { // constraint.dump(); const unify = (left: Type, right: Type): boolean => { - const resolveType = (type: Type): Type => { + const find = (type: Type): Type => { while (type.kind === TypeKind.Var && solution.has(type)) { type = solution.get(type)!; } return type; } - const simplifyType = (type: Type): Type => { - - type = resolveType(type); - - if (type.kind === TypeKind.App) { - const stack = []; - let i = 0; - let operatorType: Type = type; - do { - operatorType = resolveType(operatorType.operatorType); - } while (operatorType.kind === TypeKind.App); - assert(isKindedType(operatorType)); - let curr: Type = resolveType(type); - for (;;) { - while (curr.kind === TypeKind.App) { - stack.push(curr); - curr = resolveType(curr.operatorType); - } - if (curr !== operatorType) { - assert(i < operatorType.typeVars!.length); - unify(operatorType.typeVars![i++], curr); - } - if (stack.length === 0) { - break; - } - const next = stack.pop()!; - curr = resolveType(next.argType); - } - return operatorType; - } - return type; - } - - left = simplifyType(left); - right = simplifyType(right); + left = find(left); + right = find(right); if (left.kind === TypeKind.Var) { if (right.hasTypeVar(left)) { @@ -2145,23 +2142,6 @@ export class Checker { } } - // if (left.kind === TypeKind.App && right.kind === TypeKind.App) { - // let leftElements = [...left.getSequence()]; - // let rightElements = [...right.getSequence()]; - // if (leftElements.length !== rightElements.length) { - // this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node)); - // return false; - // } - // const count = leftElements.length; - // let success = true; - // for (let i = 0; i < count; i++) { - // if (!unify(leftElements[i], rightElements[i])) { - // success = false; - // } - // } - // return success; - // } - if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) { let success = false; // This works like an ordinary union-find algorithm where an additional @@ -2191,6 +2171,19 @@ export class Checker { return success; } + if (left.kind === TypeKind.Variant && right.kind === TypeKind.Variant) { + if (left.decl !== right.decl) { + this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()])); + return false; + } + return true; + } + + if (left.kind === TypeKind.App && right.kind === TypeKind.App) { + return unify(left.left, right.left) + && unify(left.right, right.right); + } + if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) { if (left.decl !== right.decl) { this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()])); @@ -2218,13 +2211,6 @@ export class Checker { return success; } - // while (left.kind === TypeKind.App) { - // left = left.operatorType; - // } - // while (right.kind === TypeKind.App) { - // right = right.operatorType; - // } - if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { let success = true; if (right.fields === undefined) { diff --git a/src/cst.ts b/src/cst.ts index 6b14f5a1a..b5a8e54fe 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -1758,6 +1758,8 @@ export class EnumDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.EnumDeclaration; + public typeEnv?: TypeEnv; + public constructor( public pubKeyword: PubKeyword | null, public enumKeyword: EnumKeyword, @@ -1810,6 +1812,8 @@ export class StructDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.StructDeclaration; + public typeEnv?: TypeEnv; + public constructor( public pubKeyword: PubKeyword | null, public structKeyword: StructKeyword, @@ -1936,6 +1940,8 @@ export class TypeDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.TypeDeclaration; + public typeEnv?: TypeEnv; + public constructor( public pubKeyword: PubKeyword | null, public typeKeyword: TypeKeyword, diff --git a/src/diagnostics.ts b/src/diagnostics.ts index e745ca5c3..21bc33791 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -204,14 +204,6 @@ export function describeType(type: Type): string { case TypeKind.Record: { return type.decl.name.text; - // let out = type.decl.name.text + ' { '; - // let first = true; - // for (const [fieldName, fieldType] of type.fields) { - // if (first) first = false; - // else out += ', '; - // out += fieldName + ': ' + describeType(fieldType); - // } - // return out + ' }'; } case TypeKind.Labeled: { @@ -220,7 +212,7 @@ export function describeType(type: Type): string { } case TypeKind.App: { - return describeType(type.operatorType) + ' ' + describeType(type.argType); + return describeType(type.right) + ' ' + describeType(type.left); } } }