Make recursive function definitions work

This commit is contained in:
Sam Vervaeck 2022-09-05 17:25:55 +02:00
parent 20af138fa5
commit 062ca46752
3 changed files with 147 additions and 168 deletions

View file

@ -1,15 +1,14 @@
import {
Expression,
LetDeclaration,
Pattern,
SourceFile,
Syntax,
SyntaxKind,
TypeExpression
} from "./cst";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics";
import { assert } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
import { DirectedHashGraph, Graph, strongconnect } from "yagl"
export enum TypeKind {
Arrow,
@ -53,7 +52,9 @@ class TVar extends TypeBase {
}
public substitute(sub: TVSub): Type {
return sub.get(this) ?? this;
const other = sub.get(this);
return other === undefined
? this : other.substitute(sub);
}
}
@ -269,6 +270,10 @@ class CEqual extends ConstraintBase {
);
}
public dump(): void {
console.error(`${describeType(this.left)} ~ ${describeType(this.right)}`);
}
}
class CMany extends ConstraintBase {
@ -323,7 +328,7 @@ export interface InferContext {
typeVars: TVSet;
env: TypeEnv;
constraints: ConstraintSet;
returnType: Type;
returnType: Type | null;
}
export class Checker {
@ -331,17 +336,16 @@ export class Checker {
private nextTypeVarId = 0;
private nextConTypeId = 0;
private graph?: LabeledGraph<Syntax, Syntax>;
private currentCycle?: Map<Syntax, Type>;
//private graph?: Graph<Syntax>;
//private currentCycle?: Map<Syntax, Type>;
private stringType = new TCon(this.nextConTypeId++, [], 'String');
private intType = new TCon(this.nextConTypeId++, [], 'Int');
private boolType = new TCon(this.nextConTypeId++, [], 'Bool');
private typeEnvs: TypeEnv[] = [];
private typeVars: TVSet[] = [];
private constraints: ConstraintSet[] = [];
private returnTypes: Type[] = [];
private contexts: InferContext[] = [];
private solution = new TVSub();
public constructor(
private diagnostics: Diagnostics
@ -363,7 +367,8 @@ export class Checker {
private createTypeVar(): TVar {
const typeVar = new TVar(this.nextTypeVarId++);
this.typeVars[this.typeVars.length-1].add(typeVar);
const context = this.contexts[this.contexts.length-1];
context.typeVars.add(typeVar);
return typeVar;
}
@ -378,54 +383,33 @@ export class Checker {
}
case ConstraintKind.Equal:
{
const count = this.constraints.length;
for (let i = count-1; i > 0; i--) {
const typeVars = this.typeVars[i];
const constraints = this.constraints[i];
const count = this.contexts.length;
let i;
for (i = count-1; i > 0; i--) {
const typeVars = this.contexts[i].typeVars;
if (typeVars.intersectsType(constraint.left) || typeVars.intersectsType(constraint.right)) {
constraints.push(constraint);
return;
break;
}
}
this.constraints[0].push(constraint);
return;
this.contexts[i].constraints.push(constraint);
break;
}
}
}
private pushContext(context: InferContext) {
if (context.typeVars !== null) {
this.typeVars.push(context.typeVars);
}
if (context.env !== null) {
this.typeEnvs.push(context.env);
}
if (context.constraints !== null) {
this.constraints.push(context.constraints);
}
if (context.returnType !== null) {
this.returnTypes.push(context.returnType);
}
this.contexts.push(context);
}
private popContext(context: InferContext) {
if (context.typeVars !== null) {
this.typeVars.pop();
}
if (context.env !== null) {
this.typeEnvs.pop();
}
if (context.constraints !== null) {
this.constraints.pop();
}
if (context.returnType !== null) {
this.returnTypes.pop();
}
assert(this.contexts[this.contexts.length-1] === context);
this.contexts.pop();
}
private lookup(name: string): Scheme | null {
for (let i = this.typeEnvs.length-1; i >= 0; i--) {
const scheme = this.typeEnvs[i].get(name);
for (let i = this.contexts.length-1; i >= 0; i--) {
const typeEnv = this.contexts[i].env;
const scheme = typeEnv.get(name);
if (scheme !== undefined) {
return scheme;
}
@ -434,8 +418,9 @@ export class Checker {
}
private getReturnType(): Type {
assert(this.returnTypes.length > 0);
return this.returnTypes[this.returnTypes.length-1];
const context = this.contexts[this.contexts.length-1];
assert(context && context.returnType !== null);
return context.returnType;
}
private instantiate(scheme: Scheme): Type {
@ -445,13 +430,14 @@ export class Checker {
}
for (const constraint of scheme.constraints) {
this.addConstraint(constraint.substitute(sub));
// TODO keep record of a 'chain' of instantiations so that the diagnostics tool can output it on type error
}
return scheme.type.substitute(sub);
}
private addBinding(name: string, scheme: Scheme): void {
const env = this.typeEnvs[this.typeEnvs.length-1];
env.set(name, scheme);
const context = this.contexts[this.contexts.length-1];
context.env.set(name, scheme);
}
private forwardDeclare(node: Syntax): void {
@ -474,36 +460,7 @@ export class Checker {
}
case SyntaxKind.LetDeclaration:
{
const typeVars = new TVSet();
const env = new TypeEnv();
const constraints = new ConstraintSet();
const returnType = this.createTypeVar();
const context = { typeVars, env, constraints, returnType };
node.context = context;
this.pushContext(context);
let type;
if (node.typeAssert !== null) {
type = this.inferTypeExpression(node.typeAssert.typeExpression);
} else {
type = this.createTypeVar();
}
node.type = type;
if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) {
for (const element of node.body.elements) {
this.forwardDeclare(element);
}
}
this.popContext(context);
this.inferBindings(node.pattern, type, context.typeVars, context.constraints);
break;
}
}
}
@ -564,56 +521,8 @@ export class Checker {
}
case SyntaxKind.LetDeclaration:
{
// Get the type that was stored on the node by forwardDeclare()
const type = node.type!;
const context = node.context!;
this.pushContext(context);
const paramTypes = [];
const returnType = context.returnType;
for (const param of node.params) {
const paramType = this.createTypeVar()
this.inferBindings(param.pattern, paramType, [], []);
paramTypes.push(paramType);
}
if (node.body !== null) {
switch (node.body.kind) {
case SyntaxKind.ExprBody:
{
this.addConstraint(
new CEqual(
this.inferExpression(node.body.expression),
returnType,
node.body.expression
)
);
break;
}
case SyntaxKind.BlockBody:
{
for (const element of node.body.elements) {
this.infer(element);
}
break;
}
}
}
this.addConstraint(new CEqual(type, new TArrow(paramTypes, returnType), node));
this.popContext(context);
// FIXME these two may need to go below inferBindings
//this.typeVars.pop();
//this.constraints.pop();
break;
}
default:
throw new Error(`Unexpected ${node}`);
@ -631,22 +540,17 @@ export class Checker {
case SyntaxKind.ReferenceExpression:
{
assert(node.name.modulePath.length === 0);
const target = node.getScope().lookup(node.name.name.text) as LetDeclaration;
if (target === node.getScope().node) {
return target.type!;
}
const targetType = this.currentCycle.get(target);
if (targetType) {
return targetType;
const scope = node.getScope();
const target = scope.lookup(node.name.name.text);
if (target !== null && target.type !== undefined) {
return target.type;
}
const scheme = this.lookup(node.name.name.text);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name));
return new TAny();
}
const type = this.instantiate(scheme);
this.currentCycle.set(target, type);
return type;
return this.instantiate(scheme);
}
case SyntaxKind.CallExpression:
@ -760,8 +664,8 @@ export class Checker {
}
private computeReferenceGraph(node: SourceFile): LabeledGraph<Syntax, Syntax> {
const graph = new LabeledDirectedHashGraph<Syntax, Syntax>();
private computeReferenceGraph(node: SourceFile): Graph<Syntax> {
const graph = new DirectedHashGraph<Syntax>();
const visit = (node: Syntax, source: Syntax | null) => {
switch (node.kind) {
case SyntaxKind.ConstantExpression:
@ -777,9 +681,15 @@ export class Checker {
{
// TODO only add references to nodes on the same level
assert(node.name.modulePath.length === 0);
const target = node.getScope().lookup(node.name.name.text);
let target = node.getScope().lookup(node.name.name.text);
if (target !== null && target.kind === SyntaxKind.Param) {
target = target.parent!;
if (source !== null) {
graph.addEdge(target, source);
}
}
if (source !== null && target !== null && target.kind === SyntaxKind.LetDeclaration) {
graph.addEdge(source, target, node);
graph.addEdge(source, target);
}
break;
}
@ -864,19 +774,14 @@ export class Checker {
public check(node: SourceFile): void {
this.graph = this.computeReferenceGraph(node);
const typeVars = new TVSet();
const constraints = new ConstraintSet();
const env = new TypeEnv();
const context: InferContext = { typeVars, constraints, env, returnType: null };
this.typeVars.push(typeVars);
this.constraints.push(constraints);
this.typeEnvs.push(env);
this.pushContext(context);
const a = this.createTypeVar();
const b = this.createTypeVar();
const d = this.createTypeVar();
env.set('String', new Forall([], [], this.stringType));
env.set('Int', new Forall([], [], this.intType));
@ -889,38 +794,113 @@ export class Checker {
env.set('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
env.set('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));
//this.infer(node);
for (const node of this.graph.getVertices()) {
this.forwardDeclare(node);
}
for (const nodes of strongconnect(this.graph)) {
this.currentCycle = new Map();
const graph = this.computeReferenceGraph(node);
for (const nodes of strongconnect(graph)) {
const typeVars = new TVSet();
const constraints = new ConstraintSet();
for (const node of nodes) {
for (const node of nodes) {
this.currentCycle.set(node, null);
assert(node.kind === SyntaxKind.LetDeclaration);
const env = new TypeEnv();
const context: InferContext = {
typeVars,
constraints,
env,
returnType: null,
};
node.context = context;
this.contexts.push(context);
const returnType = this.createTypeVar();
context.returnType = returnType;
const paramTypes = [];
for (const param of node.params) {
const paramType = this.createTypeVar()
this.inferBindings(param.pattern, paramType, [], []);
paramTypes.push(paramType);
}
this.infer(node);
let type = new TArrow(paramTypes, returnType);
if (node.typeAssert !== null) {
this.addConstraint(
new CEqual(
this.inferTypeExpression(node.typeAssert.typeExpression),
type,
node.typeAssert
)
);
}
node.type = type;
this.contexts.pop();
this.inferBindings(node.pattern, type, typeVars, constraints);
}
for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
const context = node.context!;
const returnType = context.returnType!;
this.contexts.push(context);
if (node.body !== null) {
switch (node.body.kind) {
case SyntaxKind.ExprBody:
{
this.addConstraint(
new CEqual(
this.inferExpression(node.body.expression),
returnType,
node.body.expression
)
);
break;
}
case SyntaxKind.BlockBody:
{
for (const element of node.body.elements) {
this.infer(element);
}
break;
}
}
}
this.contexts.pop();
}
for (const node of nodes) {
assert(node.kind === SyntaxKind.LetDeclaration);
delete node.type;
}
}
this.currentCycle = new Map();
for (const element of node.elements) {
if (element.kind !== SyntaxKind.LetDeclaration) {
//this.forwardDeclare(element);
this.infer(element);
}
}
this.typeVars.pop();
this.constraints.pop();
this.typeEnvs.pop();
//this.forwardDeclare(node);
//this.infer(node);
this.solve(new CMany(constraints));
this.popContext(context);
this.solve(new CMany(constraints), this.solution);
}
private solve(constraint: Constraint): TVSub {
private solve(constraint: Constraint, solution: TVSub): void {
const queue = [ constraint ];
const solution = new TVSub();
while (queue.length > 0) {
@ -953,8 +933,6 @@ export class Checker {
}
return solution;
}
private unify(left: Type, right: Type, solution: TVSub): boolean {

View file

@ -172,6 +172,7 @@ export type Syntax
| Param
| Body
| StructDeclarationField
| TypeAssert
| Declaration
| Statement
| Expression

View file

@ -145,7 +145,7 @@ export class BindingNotFoudDiagnostic {
}
function describeType(type: Type): string {
export function describeType(type: Type): string {
switch (type.kind) {
case TypeKind.Any:
return 'Any';