From 88e09052e6de5f7966afa91cf92a99d5102e30f2 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Mon, 5 Sep 2022 19:33:08 +0200 Subject: [PATCH] Fix support for nesting of scopes in type-checker --- src/checker.ts | 129 +++++++++++++++++++++++++++++++++++-------------- src/cst.ts | 12 ++++- 2 files changed, 102 insertions(+), 39 deletions(-) diff --git a/src/checker.ts b/src/checker.ts index 5705279a5..3fe07000e 100644 --- a/src/checker.ts +++ b/src/checker.ts @@ -321,7 +321,30 @@ class Forall extends SchemeBase { type Scheme = Forall -class TypeEnv extends Map { +class TypeEnv { + + private mapping = new Map(); + + public constructor(public parent: TypeEnv | null = null) { + + } + + public add(name: string, scheme: Scheme): void { + this.mapping.set(name, scheme); + } + + public lookup(name: string): Scheme | null { + let curr: TypeEnv | null = this; + do { + const scheme = curr.mapping.get(name); + if (scheme !== undefined) { + return scheme; + } + curr = curr.parent; + } while(curr !== null); + return null; + } + } export interface InferContext { @@ -407,14 +430,8 @@ export class Checker { } private lookup(name: string): Scheme | null { - for (let i = this.contexts.length-1; i >= 0; i--) { - const typeEnv = this.contexts[i].env; - const scheme = typeEnv.get(name); - if (scheme !== undefined) { - return scheme; - } - } - return null; + const context = this.contexts[this.contexts.length-1]; + return context.env.lookup(name); } private getReturnType(): Type { @@ -437,7 +454,7 @@ export class Checker { private addBinding(name: string, scheme: Scheme): void { const context = this.contexts[this.contexts.length-1]; - context.env.set(name, scheme); + context.env.add(name, scheme); } private forwardDeclare(node: Syntax): void { @@ -542,7 +559,7 @@ export class Checker { assert(node.name.modulePath.length === 0); const scope = node.getScope(); const target = scope.lookup(node.name.name.text); - if (target !== null && target.type !== undefined) { + if (target !== null && target.active) { return target.type; } const scheme = this.lookup(node.name.name.text); @@ -679,15 +696,8 @@ export class Checker { } case SyntaxKind.ReferenceExpression: { - // TODO only add references to nodes on the same level assert(node.name.modulePath.length === 0); - let target = node.getScope().lookup(node.name.name.text); - if (target !== null && target.kind === SyntaxKind.Param) { - target = target.parent!; - if (source !== null) { - graph.addEdge(target, source); - } - } + const target = node.getScope().lookup(node.name.name.text); if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) { graph.addEdge(source, target); } @@ -772,6 +782,41 @@ export class Checker { return graph; } + private initialize(node: Syntax, parentEnv: TypeEnv | null): void { + + switch (node.kind) { + + case SyntaxKind.SourceFile: + { + for (const element of node.elements) { + this.initialize(element, parentEnv); + } + break; + } + + case SyntaxKind.LetDeclaration: + { + const env = node.typeEnv = new TypeEnv(parentEnv); + if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) { + for (const element of node.body.elements) { + this.initialize(element, env); + } + } + break; + } + + case SyntaxKind.ExpressionStatement: + case SyntaxKind.ReturnStatement: + case SyntaxKind.StructDeclaration: + break; + + default: + throw new Error(`Unexpected ${node}`); + + } + + } + public check(node: SourceFile): void { const typeVars = new TVSet(); @@ -783,20 +828,24 @@ export class Checker { const a = this.createTypeVar(); - env.set('String', new Forall([], [], this.stringType)); - env.set('Int', new Forall([], [], this.intType)); - env.set('True', new Forall([], [], this.boolType)); - env.set('False', new Forall([], [], this.boolType)); - env.set('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.set('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.set('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.set('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); - env.set('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType))); - env.set('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType))); + env.add('String', new Forall([], [], this.stringType)); + env.add('Int', new Forall([], [], this.intType)); + env.add('True', new Forall([], [], this.boolType)); + env.add('False', new Forall([], [], this.boolType)); + env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); + env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); + env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); + env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType))); + env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType))); + env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType))); const graph = this.computeReferenceGraph(node); - for (const nodes of strongconnect(graph)) { + this.initialize(node, env); + + const sccs = [...strongconnect(graph)]; + + for (const nodes of sccs) { const typeVars = new TVSet(); const constraints = new ConstraintSet(); @@ -805,7 +854,7 @@ export class Checker { assert(node.kind === SyntaxKind.LetDeclaration); - const env = new TypeEnv(); + const env = node.typeEnv!; const context: InferContext = { typeVars, constraints, @@ -843,6 +892,14 @@ export class Checker { this.inferBindings(node.pattern, type, typeVars, constraints); } + } + + for (const nodes of sccs) { + + for (const node of nodes) { + node.active = true; + } + for (const node of nodes) { assert(node.kind === SyntaxKind.LetDeclaration); @@ -867,7 +924,9 @@ export class Checker { case SyntaxKind.BlockBody: { for (const element of node.body.elements) { - this.infer(element); + if (element.kind !== SyntaxKind.LetDeclaration) { + this.infer(element); + } } break; } @@ -878,8 +937,7 @@ export class Checker { } for (const node of nodes) { - assert(node.kind === SyntaxKind.LetDeclaration); - delete node.type; + node.active = false; } } @@ -890,9 +948,6 @@ export class Checker { } } - //this.forwardDeclare(node); - //this.infer(node); - this.popContext(context); this.solve(new CMany(constraints), this.solution); diff --git a/src/cst.ts b/src/cst.ts index 36150ef89..d39134441 100644 --- a/src/cst.ts +++ b/src/cst.ts @@ -1,5 +1,5 @@ import { JSONObject, JSONValue } from "./util"; -import type { InferContext, Type } from "./checker" +import type { InferContext, Type, TypeEnv } from "./checker" export type TextSpan = [number, number]; @@ -230,7 +230,13 @@ export class Scope { for (const param of node.params) { this.scanPattern(param.pattern, param); } - if (node !== this.node) { + if (node === this.node) { + if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) { + for (const element of node.body.elements) { + this.scan(element); + } + } + } else { this.scanPattern(node.pattern, node); } break; @@ -1586,6 +1592,8 @@ export class LetDeclaration extends SyntaxBase { public scope?: Scope; public type?: Type; + public active?: boolean; + public typeEnv?: TypeEnv; public context?: InferContext; public constructor(