From 2f359107c4b6b0af4829f4cfffd7f6cc967eaa6a Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Thu, 15 Sep 2022 11:49:53 +0200 Subject: [PATCH] Multiple fixes related to the type-checker - Add more tests - Make struct-declarations type-check - Split environment into type bindings and variable bindings - Fix kind inference adding the wrong element to the env --- src/checker.ts | 167 ++++++++++++++++++------------------- src/cst.ts | 24 +++--- src/diagnostics.ts | 2 +- src/test/type-inference.md | 40 +++++++++ src/util.ts | 68 +++++++++++++++ 5 files changed, 199 insertions(+), 102 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index ae43a3f66..8d72f953c 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -22,7 +22,7 @@ import { UnificationFailedDiagnostic, KindMismatchDiagnostic } from "./diagnostics"; -import { assert, isEmpty } from "./util"; +import { assert, isEmpty, MultiMap } from "./util"; import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl" const MAX_TYPE_ERROR_COUNT = 5; @@ -340,10 +340,10 @@ export class TApp extends TypeBase { super(node); } - public static build(types: Type[]): Type { + public static build(types: Type[], node: Syntax | null = null): Type { let result = types[0]; for (let i = 1; i < types.length; i++) { - result = new TApp(result, types[i]); + result = new TApp(result, types[i], node); } return result; } @@ -716,22 +716,23 @@ type Scheme export class TypeEnv { - private mapping = new Map(); + private mapping = new MultiMap(); public constructor(public parent: TypeEnv | null = null) { } - public add(name: string, scheme: Scheme): void { - this.mapping.set(name, scheme); + public add(name: string, scheme: Scheme, kind: Symkind): void { + this.mapping.add(name, [kind, scheme]); } - public lookup(name: string): Scheme | null { + public lookup(name: string, expectedKind: Symkind): Scheme | null { let curr: TypeEnv | null = this; do { - const scheme = curr.mapping.get(name); - if (scheme !== undefined) { - return scheme; + for (const [kind, scheme] of curr.mapping.get(name)) { + if (kind & expectedKind) { + return scheme; + } } curr = curr.parent; } while(curr !== null); @@ -842,9 +843,9 @@ export class Checker { this.contexts.pop(); } - private lookup(name: string): Scheme | null { + private lookup(name: string, kind: Symkind): Scheme | null { const context = this.contexts[this.contexts.length-1]; - return context.env.lookup(name); + return context.env.lookup(name, kind); } private getReturnType(): Type { @@ -871,9 +872,9 @@ export class Checker { return scheme.type.substitute(sub); } - private addBinding(name: string, scheme: Scheme): void { + private addBinding(name: string, scheme: Scheme, kind: Symkind): void { const context = this.contexts[this.contexts.length-1]; - context.env.add(name, scheme); + context.env.add(name, scheme, kind); } private inferKindFromTypeExpression(node: TypeExpression, env: KindEnv): Kind { @@ -959,6 +960,10 @@ export class Checker { } case SyntaxKind.StructDeclaration: + { + env.setNamed(node.name.text, this.createKindVar()); + break; + } case SyntaxKind.EnumDeclaration: { env.setNamed(node.name.text, this.createKindVar()); @@ -1136,7 +1141,7 @@ export class Checker { let type; if (node.pattern.kind === SyntaxKind.WrappedOperator) { type = this.createTypeVar(); - this.addBinding(node.pattern.operator.text, new Forall([], [], type)); + this.addBinding(node.pattern.operator.text, new Forall([], [], type), Symkind.Var); } else { type = this.inferBindings(node.pattern, [], []); } @@ -1185,27 +1190,6 @@ export class Checker { } - private buildVariantType(decl: EnumDeclarationElement, type: Type): Type { - const enumDecl = decl.parent as EnumDeclaration; - const kindArgs = []; - for (const _ of enumDecl.varExps) { - kindArgs.push(this.createTypeVar()); - } - const variantTypes = []; - if (enumDecl.members !== null) { - for (const member of enumDecl.members) { - let variantType; - if (member === decl) { - variantType = type; - } else { - variantType = this.createTypeVar(); - } - variantTypes.push(variantType); - } - } - return TApp.build([ ...kindArgs, new TVariant(enumDecl, [], []) ]); - } - public inferExpression(node: Expression): Type { switch (node.kind) { @@ -1221,7 +1205,7 @@ export class Checker { if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) { return target.type!; } - const scheme = this.lookup(node.name.name.text); + const scheme = this.lookup(node.name.name.text, Symkind.Var); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name)); return this.createTypeVar(); @@ -1285,7 +1269,7 @@ export class Checker { case SyntaxKind.NamedTupleExpression: { // TODO Only lookup constructors and skip other bindings - const scheme = this.lookup(node.name.text); + const scheme = this.lookup(node.name.text, Symkind.Var); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); return this.createTypeVar(); @@ -1310,20 +1294,19 @@ export class Checker { case SyntaxKind.StructExpression: { const scope = node.getScope(); - const decl = scope.lookup(node.name.text, Symkind.Constructor); + const decl = scope.lookup(node.name.text, Symkind.Type); if (decl === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); return this.createTypeVar(); } assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement); const scheme = decl.scheme; - const sub = this.createSubstitution(scheme); - const declType = this.instantiate(scheme, node, sub); - const argTypes = []; - for (const typeVar of decl.tvs) { - const newTypeVar = sub.get(typeVar)!; - assert(newTypeVar.kind === TypeKind.Var); - argTypes.push(newTypeVar); + const declType = this.instantiate(scheme, node); + const kindArgs = []; + const varExps = decl.kind === SyntaxKind.StructDeclaration + ? decl.varExps : (decl.parent! as EnumDeclaration).varExps; + for (const _ of varExps) { + kindArgs.push(this.createTypeVar()); } const fields = new Map(); for (const member of node.members) { @@ -1335,7 +1318,7 @@ export class Checker { } case SyntaxKind.PunnedStructExpressionField: { - const scheme = this.lookup(member.name.text); + const scheme = this.lookup(member.name.text, Symkind.Var); let fieldType; if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(member.name.text, member.name)); @@ -1350,9 +1333,10 @@ export class Checker { throw new Error(`Unexpected ${member}`); } } - let type: Type = new TRecord(decl, argTypes, fields, node); + let type: Type = TApp.build([ ...kindArgs, new TRecord(decl, [], fields, node) ]); if (decl.kind === SyntaxKind.EnumDeclarationStructElement) { - type = this.buildVariantType(decl, type); + // TODO + // type = this.buildVariantType(decl, type); } this.addConstraint( new CEqual( @@ -1366,7 +1350,7 @@ export class Checker { case SyntaxKind.InfixExpression: { - const scheme = this.lookup(node.operator.text); + const scheme = this.lookup(node.operator.text, Symkind.Var); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.operator.text, node.operator)); return this.createTypeVar(); @@ -1398,7 +1382,7 @@ export class Checker { case SyntaxKind.ReferenceTypeExpression: { - const scheme = this.lookup(node.name.text); + const scheme = this.lookup(node.name.text, Symkind.Type); if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); return this.createTypeVar(); @@ -1414,13 +1398,13 @@ export class Checker { case SyntaxKind.VarTypeExpression: { - const scheme = this.lookup(node.name.text); + const scheme = this.lookup(node.name.text, Symkind.Type); if (scheme === null) { if (!introduceTypeVars) { this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); } const type = this.createTypeVar(); - this.addBinding(node.name.text, new Forall([], [], type)); + this.addBinding(node.name.text, new Forall([], [], type), Symkind.Type); return type; } assert(isEmpty(scheme.typeVars)); @@ -1460,13 +1444,13 @@ export class Checker { case SyntaxKind.BindPattern: { const type = this.createTypeVar(); - this.addBinding(pattern.name.text, new Forall(typeVars, constraints, type)); + this.addBinding(pattern.name.text, new Forall(typeVars, constraints, type), Symkind.Var); return type; } case SyntaxKind.StructPattern: { - const scheme = this.lookup(pattern.name.text); + const scheme = this.lookup(pattern.name.text, Symkind.Type); let recordType; if (scheme === null) { this.diagnostics.add(new BindingNotFoudDiagnostic(pattern.name.text, pattern.name)); @@ -1492,7 +1476,7 @@ export class Checker { case SyntaxKind.PunnedStructPatternField: { const fieldType = this.createTypeVar(); - this.addBinding(member.name.text, new Forall([], [], fieldType)); + this.addBinding(member.name.text, new Forall([], [], fieldType), Symkind.Var); this.addConstraint( new CEqual( new TLabeled(member.name.text, fieldType), @@ -1756,7 +1740,7 @@ export class Checker { const kindArgs = []; for (const varExpr of node.varExps) { const kindArg = this.createTypeVar(); - env.add(varExpr.text, new Forall([], [], kindArg)); + env.add(varExpr.text, new Forall([], [], kindArg), Symkind.Type); kindArgs.push(kindArg); } let elementTypes: Type[] = []; @@ -1769,7 +1753,7 @@ export class Checker { { const argTypes = member.elements.map(el => this.inferTypeExpression(el)); elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ])); - parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType)); + parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType), Symkind.Var); break; } // TODO @@ -1780,7 +1764,7 @@ export class Checker { } } this.popContext(context); - parentEnv.add(node.name.text, new Forall(typeVars, constraints, type)); + parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type); break; } @@ -1796,13 +1780,13 @@ export class Checker { returnType: null, }; this.pushContext(context); - for (const varExpr of node.typeVars) { - env.add(varExpr.text, new Forall([], [], this.createTypeVar())); + for (const varExpr of node.varExps) { + env.add(varExpr.text, new Forall([], [], this.createTypeVar()), Symkind.Type); } const type = this.inferTypeExpression(node.typeExpression); this.popContext(context); const scheme = new Forall(typeVars, constraints, type); - parentEnv.add(node.name.text, scheme); + parentEnv.add(node.name.text, scheme, Symkind.Type); node.scheme = scheme; break; } @@ -1819,11 +1803,11 @@ export class Checker { returnType: null, }; this.pushContext(context); - const argTypes = []; - for (const varExpr of node.typeVars) { - const type = this.createTypeVar(); - env.add(varExpr.text, new Forall([], [], type)); - argTypes.push(type); + const kindArgs = []; + for (const varExpr of node.varExps) { + const kindArg = this.createTypeVar(); + env.add(varExpr.text, new Forall([], [], kindArg), Symkind.Type); + kindArgs.push(kindArg); } const fields = new Map(); if (node.members !== null) { @@ -1832,11 +1816,9 @@ export class Checker { } } this.popContext(context); - const type = new TRecord(node, argTypes, fields, node); - const scheme = new Forall(typeVars, constraints, type); - parentEnv.add(node.name.text, scheme); - node.tvs = argTypes; - node.scheme = scheme; //new Forall(typeVars, constraints, new TApp(type, argTypes)); + const type = new TRecord(node, [], fields, node); + parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type); + node.scheme = new Forall(typeVars, constraints, TApp.build([ ...kindArgs, type ])); break; } @@ -1867,17 +1849,17 @@ export class Checker { const b = this.createTypeVar(); const f = this.createTypeVar(); - env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b))); - env.add('String', new Forall([], [], this.stringType)); - env.add('Int', new Forall([], [], this.intType)); - env.add('True', new Forall([], [], this.boolType)); - env.add('False', new Forall([], [], this.boolType)); - env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType))); - env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType))); + env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)), Symkind.Var); + env.add('String', new Forall([], [], this.stringType), Symkind.Type); + env.add('Int', new Forall([], [], this.intType), Symkind.Type); + env.add('True', new Forall([], [], this.boolType), Symkind.Var); + env.add('False', new Forall([], [], this.boolType), Symkind.Var); + env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var); + env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)), Symkind.Var); + env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)), Symkind.Var); const graph = new LabeledDirectedHashGraph(); this.addReferencesToGraph(graph, node, node); @@ -1958,7 +1940,7 @@ export class Checker { let ty2; if (node.pattern.kind === SyntaxKind.WrappedOperator) { ty2 = this.createTypeVar(); - this.addBinding(node.pattern.operator.text, new Forall([], [], ty2)); + this.addBinding(node.pattern.operator.text, new Forall([], [], ty2), Symkind.Var); } else { ty2 = this.inferBindings(node.pattern, typeVars, constraints); } @@ -1974,7 +1956,7 @@ export class Checker { && isFunctionDeclarationLike(element) && graph.hasEdge(node, element, false)) { assert(element.pattern.kind === SyntaxKind.BindPattern); - const scheme = this.lookup(element.pattern.name.text); + const scheme = this.lookup(element.pattern.name.text, Symkind.Var); assert(scheme !== null); this.instantiate(scheme, null); } else { @@ -2210,14 +2192,23 @@ export class Checker { return success; } - if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { + let leftElement: Type = left; + while (leftElement.kind === TypeKind.App) { + leftElement = leftElement.right; + } + let rightElement: Type = right; + while (rightElement.kind === TypeKind.App) { + rightElement = rightElement.right; + } + + if (leftElement.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { let success = true; if (right.fields === undefined) { right.fields = new Map([ [ right.name, right.type ] ]); } for (const [fieldName, fieldType] of right.fields) { - if (left.fields.has(fieldName)) { - if (!unify(fieldType, left.fields.get(fieldName)!)) { + if (leftElement.fields.has(fieldName)) { + if (!unify(fieldType, leftElement.fields.get(fieldName)!)) { success = false; } } else { diff --git a/src/cst.ts b/src/cst.ts index b5a8e54fe..e9bdafd83 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -1,4 +1,4 @@ -import { JSONObject, JSONValue } from "./util"; +import { JSONObject, JSONValue, MultiMap } from "./util"; import type { InferContext, Type, TypeEnv } from "./checker" export type TextSpan = [number, number]; @@ -210,13 +210,12 @@ function isNodeWithScope(node: Syntax): node is NodeWithScope { export const enum Symkind { Var = 1, Type = 2, - Constructor = 4, - Any = Var | Type | Constructor + Any = Var | Type } export class Scope { - private mapping = new Map(); + private mapping = new MultiMap(); public constructor( public node: NodeWithScope, @@ -236,7 +235,7 @@ export class Scope { } private add(name: string, node: Syntax, kind: Symkind): void { - this.mapping.set(name, [kind, node]); + this.mapping.add(name, [kind, node]); } private scan(node: Syntax): void { @@ -262,13 +261,14 @@ export class Scope { this.add(node.name.text, node, Symkind.Type); if (node.members !== null) { for (const member of node.members) { - this.add(member.name.text, member, Symkind.Constructor); + this.add(member.name.text, member, Symkind.Var); } } } case SyntaxKind.StructDeclaration: { - this.add(node.name.text, node, Symkind.Constructor | Symkind.Type); + this.add(node.name.text, node, Symkind.Type); + this.add(node.name.text, node, Symkind.Var); break; } case SyntaxKind.LetDeclaration: @@ -326,12 +326,10 @@ export class Scope { } } - public lookup(name: string, expectedKind = Symkind.Any): Syntax | null { + public lookup(name: string, expectedKind: Symkind = Symkind.Any): Syntax | null { let curr: Scope | null = this; do { - const match = curr.mapping.get(name); - if (match !== undefined) { - const [kind, decl] = match; + for (const [kind, decl] of curr.mapping.get(name)) { if (kind & expectedKind) { return decl; } @@ -1818,7 +1816,7 @@ export class StructDeclaration extends SyntaxBase { public pubKeyword: PubKeyword | null, public structKeyword: StructKeyword, public name: IdentifierAlt, - public typeVars: Identifier[], + public varExps: Identifier[], public members: StructDeclarationField[] | null, ) { super(); @@ -1946,7 +1944,7 @@ export class TypeDeclaration extends SyntaxBase { public pubKeyword: PubKeyword | null, public typeKeyword: TypeKeyword, public name: IdentifierAlt, - public typeVars: Identifier[], + public varExps: Identifier[], public equals: Equals, public typeExpression: TypeExpression ) { diff --git a/src/diagnostics.ts b/src/diagnostics.ts index 21bc33791..307060e4a 100644 --- a/src/diagnostics.ts +++ b/src/diagnostics.ts @@ -306,7 +306,7 @@ export class FieldMissingDiagnostic { public readonly level = Level.Error; public constructor( - public recordType: TRecord, + public recordType: Type, public fieldName: string, public node: Syntax | null, ) { diff --git a/src/test/type-inference.md b/src/test/type-inference.md index fd36db41f..3fc500a9a 100644 --- a/src/test/type-inference.md +++ b/src/test/type-inference.md @@ -112,3 +112,43 @@ let fac n. else. return n * fac (n-"foo") ``` + +## Enum-declarations are correctly typed + +``` +enum Maybe a. + Just a + Nothing + +let right_1 : Maybe Int = Just 1 +let right_2 : Maybe String = Just "foo" +let wrong : Maybe Int = Just "foo" +``` + +## Kind inference works + +``` +enum Maybe a. + Just a + Nothing + +let foo_1 : Maybe +let foo_2 : Maybe Int +let foo_3 : Maybe Int Int +``` + +## Can indirectly apply a polymorphic datatype to some type + +``` +enum Maybe a. + Just a + Nothing + +enum App a b. + MkApp (a b) + +enum Foo. + MkFoo (App Maybe Int) + +let f : Foo = MkFoo (MkApp (Just 1)) +``` diff --git a/src/util.ts b/src/util.ts index 893a6c04d..a905bc899 100644 --- a/src/util.ts +++ b/src/util.ts @@ -118,3 +118,71 @@ export abstract class BufferedStream { } +export class MultiMap { + + private mapping = new Map(); + + public get(key: K): V[] { + return this.mapping.get(key) ?? []; + } + + public add(key: K, value: V): void { + let elements = this.mapping.get(key); + if (elements === undefined) { + elements = []; + this.mapping.set(key, elements); + } + elements.push(value); + } + + public has(key: K, value?: V): boolean { + if (value === undefined) { + return this.mapping.has(key); + } + const elements = this.mapping.get(key); + if (elements === undefined) { + return false; + } + return elements.indexOf(value) !== -1; + } + + public keys(): Iterable { + return this.mapping.keys(); + } + + public *values(): Iterable { + for (const elements of this.mapping.values()) { + yield* elements; + } + } + + public *[Symbol.iterator](): Iterable<[K, V]> { + for (const [key, elements] of this.mapping) { + for (const value of elements) { + yield [key, value]; + } + } + } + + public delete(key: K, value?: V): number { + const elements = this.mapping.get(key); + if (elements === undefined) { + return 0; + } + if (value === undefined) { + this.mapping.delete(key); + return elements.length; + } + const i = elements.indexOf(value); + if (i !== -1) { + elements.splice(i, 1); + if (elements.length === 0) { + this.mapping.delete(key); + } + return 1; + } + return 0; + } + +} +