Clean up some code in checker.ts

This commit is contained in:
Sam Vervaeck 2022-09-05 19:38:55 +02:00
parent 88e09052e6
commit 70f9f99181

View file

@ -1,5 +1,6 @@
import { import {
Expression, Expression,
LetDeclaration,
Pattern, Pattern,
SourceFile, SourceFile,
Syntax, Syntax,
@ -321,7 +322,7 @@ class Forall extends SchemeBase {
type Scheme type Scheme
= Forall = Forall
class TypeEnv { export class TypeEnv {
private mapping = new Map<string, Scheme>(); private mapping = new Map<string, Scheme>();
@ -559,8 +560,8 @@ export class Checker {
assert(node.name.modulePath.length === 0); assert(node.name.modulePath.length === 0);
const scope = node.getScope(); const scope = node.getScope();
const target = scope.lookup(node.name.name.text); const target = scope.lookup(node.name.name.text);
if (target !== null && target.active) { if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) {
return target.type; return target.type!;
} }
const scheme = this.lookup(node.name.name.text); const scheme = this.lookup(node.name.name.text);
if (scheme === null) { if (scheme === null) {
@ -681,19 +682,21 @@ export class Checker {
} }
private computeReferenceGraph(node: SourceFile): Graph<Syntax> { private addReferences(graph: Graph<LetDeclaration>, node: Syntax, source: LetDeclaration | null) {
const graph = new DirectedHashGraph<Syntax>();
const visit = (node: Syntax, source: Syntax | null) => {
switch (node.kind) { switch (node.kind) {
case SyntaxKind.ConstantExpression: case SyntaxKind.ConstantExpression:
break; break;
case SyntaxKind.SourceFile: case SyntaxKind.SourceFile:
{ {
for (const element of node.elements) { for (const element of node.elements) {
visit(element, source); this.addReferences(graph, element, source);
} }
break; break;
} }
case SyntaxKind.ReferenceExpression: case SyntaxKind.ReferenceExpression:
{ {
assert(node.name.modulePath.length === 0); assert(node.name.modulePath.length === 0);
@ -703,53 +706,59 @@ export class Checker {
} }
break; break;
} }
case SyntaxKind.NamedTupleExpression: case SyntaxKind.NamedTupleExpression:
{ {
for (const arg of node.elements) { for (const arg of node.elements) {
visit(arg, source); this.addReferences(graph, arg, source);
} }
break; break;
} }
case SyntaxKind.NestedExpression: case SyntaxKind.NestedExpression:
{ {
visit(node.expression, source); this.addReferences(graph, node.expression, source);
break; break;
} }
case SyntaxKind.InfixExpression: case SyntaxKind.InfixExpression:
{ {
visit(node.left, source); this.addReferences(graph, node.left, source);
visit(node.right, source); this.addReferences(graph, node.right, source);
break; break;
} }
case SyntaxKind.CallExpression: case SyntaxKind.CallExpression:
{ {
visit(node.func, source); this.addReferences(graph, node.func, source);
for (const arg of node.args) { for (const arg of node.args) {
visit(arg, source); this.addReferences(graph, arg, source);
} }
break; break;
} }
case SyntaxKind.IfStatement: case SyntaxKind.IfStatement:
{ {
for (const cs of node.cases) { for (const cs of node.cases) {
if (cs.test !== null) { if (cs.test !== null) {
visit(cs.test, source); this.addReferences(graph, cs.test, source);
} }
for (const element of cs.elements) { for (const element of cs.elements) {
visit(element, source); this.addReferences(graph, element, source);
} }
} }
break; break;
} }
case SyntaxKind.ExpressionStatement: case SyntaxKind.ExpressionStatement:
{ {
visit(node.expression, source); this.addReferences(graph, node.expression, source);
break; break;
} }
case SyntaxKind.ReturnStatement: case SyntaxKind.ReturnStatement:
{ {
if (node.expression !== null) { if (node.expression !== null) {
visit(node.expression, source); this.addReferences(graph, node.expression, source);
} }
break; break;
} }
@ -760,13 +769,13 @@ export class Checker {
switch (node.body.kind) { switch (node.body.kind) {
case SyntaxKind.ExprBody: case SyntaxKind.ExprBody:
{ {
visit(node.body.expression, node); this.addReferences(graph, node.body.expression, node);
break; break;
} }
case SyntaxKind.BlockBody: case SyntaxKind.BlockBody:
{ {
for (const element of node.body.elements) { for (const element of node.body.elements) {
visit(element, node); this.addReferences(graph, element, node);
} }
break; break;
} }
@ -774,12 +783,12 @@ export class Checker {
} }
break; break;
} }
default: default:
throw new Error(`Unexpected ${node.constructor.name}`); throw new Error(`Unexpected ${node.constructor.name}`);
} }
}
visit(node, null);
return graph;
} }
private initialize(node: Syntax, parentEnv: TypeEnv | null): void { private initialize(node: Syntax, parentEnv: TypeEnv | null): void {
@ -839,7 +848,8 @@ export class Checker {
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType))); env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType))); env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));
const graph = this.computeReferenceGraph(node); const graph = new DirectedHashGraph<LetDeclaration>();
this.addReferences(graph, node, null);
this.initialize(node, env); this.initialize(node, env);
@ -852,8 +862,6 @@ export class Checker {
for (const node of nodes) { for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
const env = node.typeEnv!; const env = node.typeEnv!;
const context: InferContext = { const context: InferContext = {
typeVars, typeVars,
@ -902,8 +910,6 @@ export class Checker {
for (const node of nodes) { for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
const context = node.context!; const context = node.context!;
const returnType = context.returnType!; const returnType = context.returnType!;
this.contexts.push(context); this.contexts.push(context);