diff --git a/src/checker.ts b/src/checker.ts index 3fe07000e..da5cc38e4 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -1,5 +1,6 @@ import { Expression, + LetDeclaration, Pattern, SourceFile, Syntax, @@ -321,7 +322,7 @@ class Forall extends SchemeBase { type Scheme = Forall -class TypeEnv { +export class TypeEnv { private mapping = new Map(); @@ -559,8 +560,8 @@ export class Checker { assert(node.name.modulePath.length === 0); const scope = node.getScope(); const target = scope.lookup(node.name.name.text); - if (target !== null && target.active) { - return target.type; + if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) { + return target.type!; } const scheme = this.lookup(node.name.name.text); if (scheme === null) { @@ -681,105 +682,113 @@ export class Checker { } - private computeReferenceGraph(node: SourceFile): Graph { - const graph = new DirectedHashGraph(); - const visit = (node: Syntax, source: Syntax | null) => { - switch (node.kind) { - case SyntaxKind.ConstantExpression: - break; - case SyntaxKind.SourceFile: - { - for (const element of node.elements) { - visit(element, source); - } - break; + private addReferences(graph: Graph, node: Syntax, source: LetDeclaration | null) { + + switch (node.kind) { + + case SyntaxKind.ConstantExpression: + break; + + case SyntaxKind.SourceFile: + { + for (const element of node.elements) { + this.addReferences(graph, element, source); } - case SyntaxKind.ReferenceExpression: - { - assert(node.name.modulePath.length === 0); - const target = node.getScope().lookup(node.name.name.text); - if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) { - graph.addEdge(source, target); - } - break; - } - case SyntaxKind.NamedTupleExpression: - { - for (const arg of node.elements) { - visit(arg, source); - } - break; - } - case SyntaxKind.NestedExpression: - { - visit(node.expression, source); - break; - } - case SyntaxKind.InfixExpression: - { - visit(node.left, source); - visit(node.right, source); - break; - } - case SyntaxKind.CallExpression: - { - visit(node.func, source); - for (const arg of node.args) { - visit(arg, source); - } - break; - } - case SyntaxKind.IfStatement: - { - for (const cs of node.cases) { - if (cs.test !== null) { - visit(cs.test, source); - } - for (const element of cs.elements) { - visit(element, source); - } - } - break; - } - case SyntaxKind.ExpressionStatement: - { - visit(node.expression, source); - break; - } - case SyntaxKind.ReturnStatement: - { - if (node.expression !== null) { - visit(node.expression, source); - } - break; - } - case SyntaxKind.LetDeclaration: - { - graph.addVertex(node); - if (node.body !== null) { - switch (node.body.kind) { - case SyntaxKind.ExprBody: - { - visit(node.body.expression, node); - break; - } - case SyntaxKind.BlockBody: - { - for (const element of node.body.elements) { - visit(element, node); - } - break; - } - } - } - break; - } - default: - throw new Error(`Unexpected ${node.constructor.name}`); + break; } + + case SyntaxKind.ReferenceExpression: + { + assert(node.name.modulePath.length === 0); + const target = node.getScope().lookup(node.name.name.text); + if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) { + graph.addEdge(source, target); + } + break; + } + + case SyntaxKind.NamedTupleExpression: + { + for (const arg of node.elements) { + this.addReferences(graph, arg, source); + } + break; + } + + case SyntaxKind.NestedExpression: + { + this.addReferences(graph, node.expression, source); + break; + } + + case SyntaxKind.InfixExpression: + { + this.addReferences(graph, node.left, source); + this.addReferences(graph, node.right, source); + break; + } + + case SyntaxKind.CallExpression: + { + this.addReferences(graph, node.func, source); + for (const arg of node.args) { + this.addReferences(graph, arg, source); + } + break; + } + + case SyntaxKind.IfStatement: + { + for (const cs of node.cases) { + if (cs.test !== null) { + this.addReferences(graph, cs.test, source); + } + for (const element of cs.elements) { + this.addReferences(graph, element, source); + } + } + break; + } + + case SyntaxKind.ExpressionStatement: + { + this.addReferences(graph, node.expression, source); + break; + } + case SyntaxKind.ReturnStatement: + { + if (node.expression !== null) { + this.addReferences(graph, node.expression, source); + } + break; + } + case SyntaxKind.LetDeclaration: + { + graph.addVertex(node); + if (node.body !== null) { + switch (node.body.kind) { + case SyntaxKind.ExprBody: + { + this.addReferences(graph, node.body.expression, node); + break; + } + case SyntaxKind.BlockBody: + { + for (const element of node.body.elements) { + this.addReferences(graph, element, node); + } + break; + } + } + } + break; + } + + default: + throw new Error(`Unexpected ${node.constructor.name}`); + } - visit(node, null); - return graph; + } private initialize(node: Syntax, parentEnv: TypeEnv | null): void { @@ -839,7 +848,8 @@ export class Checker { env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType))); env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType))); - const graph = this.computeReferenceGraph(node); + const graph = new DirectedHashGraph(); + this.addReferences(graph, node, null); this.initialize(node, env); @@ -852,8 +862,6 @@ export class Checker { for (const node of nodes) { - assert(node.kind === SyntaxKind.LetDeclaration); - const env = node.typeEnv!; const context: InferContext = { typeVars, @@ -902,8 +910,6 @@ export class Checker { for (const node of nodes) { - assert(node.kind === SyntaxKind.LetDeclaration); - const context = node.context!; const returnType = context.returnType!; this.contexts.push(context);