Fix support for nesting of scopes in type-checker

This commit is contained in:
Sam Vervaeck 2022-09-05 19:33:08 +02:00
parent 062ca46752
commit 88e09052e6
2 changed files with 102 additions and 39 deletions

View file

@ -321,7 +321,30 @@ class Forall extends SchemeBase {
type Scheme
= Forall
class TypeEnv extends Map<string, Scheme> {
class TypeEnv {
private mapping = new Map<string, Scheme>();
public constructor(public parent: TypeEnv | null = null) {
}
public add(name: string, scheme: Scheme): void {
this.mapping.set(name, scheme);
}
public lookup(name: string): Scheme | null {
let curr: TypeEnv | null = this;
do {
const scheme = curr.mapping.get(name);
if (scheme !== undefined) {
return scheme;
}
curr = curr.parent;
} while(curr !== null);
return null;
}
}
export interface InferContext {
@ -407,14 +430,8 @@ export class Checker {
}
private lookup(name: string): Scheme | null {
for (let i = this.contexts.length-1; i >= 0; i--) {
const typeEnv = this.contexts[i].env;
const scheme = typeEnv.get(name);
if (scheme !== undefined) {
return scheme;
}
}
return null;
const context = this.contexts[this.contexts.length-1];
return context.env.lookup(name);
}
private getReturnType(): Type {
@ -437,7 +454,7 @@ export class Checker {
private addBinding(name: string, scheme: Scheme): void {
const context = this.contexts[this.contexts.length-1];
context.env.set(name, scheme);
context.env.add(name, scheme);
}
private forwardDeclare(node: Syntax): void {
@ -542,7 +559,7 @@ export class Checker {
assert(node.name.modulePath.length === 0);
const scope = node.getScope();
const target = scope.lookup(node.name.name.text);
if (target !== null && target.type !== undefined) {
if (target !== null && target.active) {
return target.type;
}
const scheme = this.lookup(node.name.name.text);
@ -679,15 +696,8 @@ export class Checker {
}
case SyntaxKind.ReferenceExpression:
{
// TODO only add references to nodes on the same level
assert(node.name.modulePath.length === 0);
let target = node.getScope().lookup(node.name.name.text);
if (target !== null && target.kind === SyntaxKind.Param) {
target = target.parent!;
if (source !== null) {
graph.addEdge(target, source);
}
}
const target = node.getScope().lookup(node.name.name.text);
if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) {
graph.addEdge(source, target);
}
@ -772,6 +782,41 @@ export class Checker {
return graph;
}
private initialize(node: Syntax, parentEnv: TypeEnv | null): void {
switch (node.kind) {
case SyntaxKind.SourceFile:
{
for (const element of node.elements) {
this.initialize(element, parentEnv);
}
break;
}
case SyntaxKind.LetDeclaration:
{
const env = node.typeEnv = new TypeEnv(parentEnv);
if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
for (const element of node.body.elements) {
this.initialize(element, env);
}
}
break;
}
case SyntaxKind.ExpressionStatement:
case SyntaxKind.ReturnStatement:
case SyntaxKind.StructDeclaration:
break;
default:
throw new Error(`Unexpected ${node}`);
}
}
public check(node: SourceFile): void {
const typeVars = new TVSet();
@ -783,20 +828,24 @@ export class Checker {
const a = this.createTypeVar();
env.set('String', new Forall([], [], this.stringType));
env.set('Int', new Forall([], [], this.intType));
env.set('True', new Forall([], [], this.boolType));
env.set('False', new Forall([], [], this.boolType));
env.set('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.set('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.set('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.set('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.set('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
env.set('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));
env.add('String', new Forall([], [], this.stringType));
env.add('Int', new Forall([], [], this.intType));
env.add('True', new Forall([], [], this.boolType));
env.add('False', new Forall([], [], this.boolType));
env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));
const graph = this.computeReferenceGraph(node);
for (const nodes of strongconnect(graph)) {
this.initialize(node, env);
const sccs = [...strongconnect(graph)];
for (const nodes of sccs) {
const typeVars = new TVSet();
const constraints = new ConstraintSet();
@ -805,7 +854,7 @@ export class Checker {
assert(node.kind === SyntaxKind.LetDeclaration);
const env = new TypeEnv();
const env = node.typeEnv!;
const context: InferContext = {
typeVars,
constraints,
@ -843,6 +892,14 @@ export class Checker {
this.inferBindings(node.pattern, type, typeVars, constraints);
}
}
for (const nodes of sccs) {
for (const node of nodes) {
node.active = true;
}
for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
@ -867,7 +924,9 @@ export class Checker {
case SyntaxKind.BlockBody:
{
for (const element of node.body.elements) {
this.infer(element);
if (element.kind !== SyntaxKind.LetDeclaration) {
this.infer(element);
}
}
break;
}
@ -878,8 +937,7 @@ export class Checker {
}
for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
delete node.type;
node.active = false;
}
}
@ -890,9 +948,6 @@ export class Checker {
}
}
//this.forwardDeclare(node);
//this.infer(node);
this.popContext(context);
this.solve(new CMany(constraints), this.solution);

View file

@ -1,5 +1,5 @@
import { JSONObject, JSONValue } from "./util";
import type { InferContext, Type } from "./checker"
import type { InferContext, Type, TypeEnv } from "./checker"
export type TextSpan = [number, number];
@ -230,7 +230,13 @@ export class Scope {
for (const param of node.params) {
this.scanPattern(param.pattern, param);
}
if (node !== this.node) {
if (node === this.node) {
if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
for (const element of node.body.elements) {
this.scan(element);
}
}
} else {
this.scanPattern(node.pattern, node);
}
break;
@ -1586,6 +1592,8 @@ export class LetDeclaration extends SyntaxBase {
public scope?: Scope;
public type?: Type;
public active?: boolean;
public typeEnv?: TypeEnv;
public context?: InferContext;
public constructor(