Clean up some code in checker.ts

This commit is contained in:
Sam Vervaeck 2022-09-05 19:38:55 +02:00
parent 88e09052e6
commit 70f9f99181

View file

@ -1,5 +1,6 @@
import {
Expression,
LetDeclaration,
Pattern,
SourceFile,
Syntax,
@ -321,7 +322,7 @@ class Forall extends SchemeBase {
type Scheme
= Forall
class TypeEnv {
export class TypeEnv {
private mapping = new Map<string, Scheme>();
@ -559,8 +560,8 @@ export class Checker {
assert(node.name.modulePath.length === 0);
const scope = node.getScope();
const target = scope.lookup(node.name.name.text);
if (target !== null && target.active) {
return target.type;
if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) {
return target.type!;
}
const scheme = this.lookup(node.name.name.text);
if (scheme === null) {
@ -681,105 +682,113 @@ export class Checker {
}
private computeReferenceGraph(node: SourceFile): Graph<Syntax> {
const graph = new DirectedHashGraph<Syntax>();
const visit = (node: Syntax, source: Syntax | null) => {
switch (node.kind) {
case SyntaxKind.ConstantExpression:
break;
case SyntaxKind.SourceFile:
{
for (const element of node.elements) {
visit(element, source);
}
break;
private addReferences(graph: Graph<LetDeclaration>, node: Syntax, source: LetDeclaration | null) {
switch (node.kind) {
case SyntaxKind.ConstantExpression:
break;
case SyntaxKind.SourceFile:
{
for (const element of node.elements) {
this.addReferences(graph, element, source);
}
case SyntaxKind.ReferenceExpression:
{
assert(node.name.modulePath.length === 0);
const target = node.getScope().lookup(node.name.name.text);
if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) {
graph.addEdge(source, target);
}
break;
}
case SyntaxKind.NamedTupleExpression:
{
for (const arg of node.elements) {
visit(arg, source);
}
break;
}
case SyntaxKind.NestedExpression:
{
visit(node.expression, source);
break;
}
case SyntaxKind.InfixExpression:
{
visit(node.left, source);
visit(node.right, source);
break;
}
case SyntaxKind.CallExpression:
{
visit(node.func, source);
for (const arg of node.args) {
visit(arg, source);
}
break;
}
case SyntaxKind.IfStatement:
{
for (const cs of node.cases) {
if (cs.test !== null) {
visit(cs.test, source);
}
for (const element of cs.elements) {
visit(element, source);
}
}
break;
}
case SyntaxKind.ExpressionStatement:
{
visit(node.expression, source);
break;
}
case SyntaxKind.ReturnStatement:
{
if (node.expression !== null) {
visit(node.expression, source);
}
break;
}
case SyntaxKind.LetDeclaration:
{
graph.addVertex(node);
if (node.body !== null) {
switch (node.body.kind) {
case SyntaxKind.ExprBody:
{
visit(node.body.expression, node);
break;
}
case SyntaxKind.BlockBody:
{
for (const element of node.body.elements) {
visit(element, node);
}
break;
}
}
}
break;
}
default:
throw new Error(`Unexpected ${node.constructor.name}`);
break;
}
case SyntaxKind.ReferenceExpression:
{
assert(node.name.modulePath.length === 0);
const target = node.getScope().lookup(node.name.name.text);
if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) {
graph.addEdge(source, target);
}
break;
}
case SyntaxKind.NamedTupleExpression:
{
for (const arg of node.elements) {
this.addReferences(graph, arg, source);
}
break;
}
case SyntaxKind.NestedExpression:
{
this.addReferences(graph, node.expression, source);
break;
}
case SyntaxKind.InfixExpression:
{
this.addReferences(graph, node.left, source);
this.addReferences(graph, node.right, source);
break;
}
case SyntaxKind.CallExpression:
{
this.addReferences(graph, node.func, source);
for (const arg of node.args) {
this.addReferences(graph, arg, source);
}
break;
}
case SyntaxKind.IfStatement:
{
for (const cs of node.cases) {
if (cs.test !== null) {
this.addReferences(graph, cs.test, source);
}
for (const element of cs.elements) {
this.addReferences(graph, element, source);
}
}
break;
}
case SyntaxKind.ExpressionStatement:
{
this.addReferences(graph, node.expression, source);
break;
}
case SyntaxKind.ReturnStatement:
{
if (node.expression !== null) {
this.addReferences(graph, node.expression, source);
}
break;
}
case SyntaxKind.LetDeclaration:
{
graph.addVertex(node);
if (node.body !== null) {
switch (node.body.kind) {
case SyntaxKind.ExprBody:
{
this.addReferences(graph, node.body.expression, node);
break;
}
case SyntaxKind.BlockBody:
{
for (const element of node.body.elements) {
this.addReferences(graph, element, node);
}
break;
}
}
}
break;
}
default:
throw new Error(`Unexpected ${node.constructor.name}`);
}
visit(node, null);
return graph;
}
private initialize(node: Syntax, parentEnv: TypeEnv | null): void {
@ -839,7 +848,8 @@ export class Checker {
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));
const graph = this.computeReferenceGraph(node);
const graph = new DirectedHashGraph<LetDeclaration>();
this.addReferences(graph, node, null);
this.initialize(node, env);
@ -852,8 +862,6 @@ export class Checker {
for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
const env = node.typeEnv!;
const context: InferContext = {
typeVars,
@ -902,8 +910,6 @@ export class Checker {
for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
const context = node.context!;
const returnType = context.returnType!;
this.contexts.push(context);