From e92e346bade1077df4ea9727431731e98996db3c Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sat, 12 Aug 2023 13:46:19 +0200 Subject: [PATCH] Decouple type checking info from CST and refactor checker.ts a bit --- README.md | 16 ++ compiler/src/checker.ts | 579 +++++++++++++++++++++++----------------- compiler/src/cst.ts | 41 +-- compiler/src/types.ts | 13 +- 4 files changed, 365 insertions(+), 284 deletions(-) diff --git a/README.md b/README.md index 34860a67a..8610ffd4a 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,22 @@ nice goodies, including: - **Cross-platform standard library**, allowing you to write your code for the web and the desktop at the same time. +## Examples + +_Note that these examples are stil in the design phase and not able to compile._ + +```ocaml +import "html" ( HtmlComponent ) + +let app : HtmlComponent. + + let { fullname, .. } = perform get_app_state + + return match user. + None => h1 [ "Please log in." ] + Some({ fullname, .. }) => +``` + ## Core Principles Bolt has a few fundamental design principles that we hope in time will make it diff --git a/compiler/src/checker.ts b/compiler/src/checker.ts index 07782ced3..765552bff 100644 --- a/compiler/src/checker.ts +++ b/compiler/src/checker.ts @@ -30,8 +30,10 @@ import { import { assert, assertNever, isEmpty, MultiMap, toStringTag, InspectFn } from "./util"; import { Analyser } from "./analysis"; import { InspectOptions } from "util"; -import { TypeKind, TApp, TArrow, TCon, TField, TNil, TPresent, TUniVar, TVSet, TVSub, Type, TypeBase, TAbsent, TRigidVar, TVar, buildTupleTypeWithLoc, buildTupleType } from "./types"; +import { TypeKind, TApp, TArrow, TCon, TField, TNil, TPresent, TRegularVar, TVSet, TVSub, Type, TypeBase, TAbsent, TRigidVar, TVar, buildTupleTypeWithLoc, buildTupleType } from "./types"; import { CEmpty, CEqual, CMany, Constraint, ConstraintKind, ConstraintSet } from "./constraints"; +import { warn } from "console"; +import { wrap } from "module"; // export class Qual { @@ -214,7 +216,7 @@ class Forall extends SchemeBase { return new Forall(new TVSet, new CEmpty, type); } - public static fromArrays(typeVars: TUniVar[], constraints: Constraint[], type: Type): Forall { + public static fromArrays(typeVars: TRegularVar[], constraints: Constraint[], type: Type): Forall { return new Forall(new TVSet(typeVars), new CMany(constraints), type); } @@ -327,11 +329,24 @@ function splitReferences(node: NodeWithReference): [IdentifierAlt[], Identifier return [modulePath, name] } -export interface InferContext { - typeVars: TVSet; - env: TypeEnv; - constraints: ConstraintSet; - returnType: Type | null; +class PolyContext { + + public constructor( + public typeVars = new TVSet(), + public constraints: ConstraintSet = [], + ) { + + } + +} + +export interface TCInfo { + inferredType?: Type; + inferredKind?: Kind; + poly?: PolyContext; + kindEnv?: KindEnv; + typeEnv?: TypeEnv; + returnType?: Type | null; } function isSignatureDeclarationLike(node: LetDeclaration): boolean { @@ -348,13 +363,6 @@ function isFunctionDeclarationLike(node: LetDeclaration): boolean { // && (node.params.length > 0 || (node.body !== null && node.body.kind === SyntaxKind.BlockBody)); } -function* integers(start: number = 0) { - let i = start; - for (;;) { - yield i++; - } -} - function hasTypeVar(typeVars: TVSet, type: Type): boolean { for (const tv of type.getTypeVars()) { if (typeVars.has(tv)) { @@ -366,7 +374,7 @@ function hasTypeVar(typeVars: TVSet, type: Type): boolean { export class Checker { - private typeVarIds = integers(); + private nextTypeVarId = 0; private nextKindVarId = 0; private nextConTypeId = 0; @@ -375,15 +383,17 @@ export class Checker { private boolType = this.createTCon('Bool'); private unitType = buildTupleType([]); - private contexts: InferContext[] = []; - private classDecls = new Map(); private globalKindEnv = new KindEnv(); private globalTypeEnv = new TypeEnv(); - private solution = new TVSub(); + private typeSolution = new TVSub(); private kindSolution = new KVSub(); + private typeEnvStack: TypeEnv[] = []; + private polyContextStack: PolyContext[] = []; + private returnTypeStack: (Type | null)[] = []; + public constructor( private analyser: Analyser, private diagnostics: Diagnostics @@ -393,8 +403,8 @@ export class Checker { this.globalKindEnv.set('String', kindOfTypes); this.globalKindEnv.set('Bool', kindOfTypes); - const a = new TUniVar(this.typeVarIds.next().value!); - const b = new TUniVar(this.typeVarIds.next().value!); + const a = new TRegularVar(this.nextTypeVarId++); + const b = new TRegularVar(this.nextTypeVarId++); this.globalTypeEnv.add('$', Forall.fromArrays([ a, b ], [], new TArrow(new TArrow(new TArrow(a, b), a), b)), Symkind.Var); this.globalTypeEnv.add('String', Forall.fromArrays([], [], this.stringType), Symkind.Type); @@ -427,76 +437,116 @@ export class Checker { return new TCon(this.nextConTypeId++, name, node); } - private createTypeVar(node: Syntax | null = null): TUniVar { - const typeVar = new TUniVar(this.typeVarIds.next().value!, node); - this.getContext().typeVars.add(typeVar); + private getInfo(node: Syntax): TCInfo { + return node as unknown as TCInfo; + } + + private getPolyContext(): PolyContext { + return this.polyContextStack[this.polyContextStack.length-1]; + } + + private pushInfo(info: TCInfo): void { + if (info.poly !== undefined) { + this.polyContextStack.push(info.poly); + } + if (info.returnType !== undefined) { + this.returnTypeStack.push(info.returnType); + } + if (info.typeEnv !== undefined) { + this.typeEnvStack.push(info.typeEnv); + } + } + + private popInfo(info: TCInfo): void { + if (info.poly !== undefined) { + this.polyContextStack.pop(); + } + if (info.returnType !== undefined) { + this.returnTypeStack.pop(); + } + if (info.typeEnv !== undefined) { + this.typeEnvStack.pop(); + } + } + + public getReturnType(): Type { + const ty = this.returnTypeStack[this.returnTypeStack.length-1]; + assert(ty !== null); + return ty; + } + + private getTypeEnv(): TypeEnv { + return this.typeEnvStack[this.typeEnvStack.length-1]; + } + + private createTRegularVar(node: Syntax | null = null): TRegularVar { + const typeVar = new TRegularVar(this.nextTypeVarId++, node); + this.getPolyContext().typeVars.add(typeVar); return typeVar; } private createRigidVar(displayName: string, node: Syntax | null = null): TRigidVar { - const tv = new TRigidVar(this.typeVarIds.next().value!, displayName, node); - this.getContext().typeVars.add(tv); + const tv = new TRigidVar(this.nextTypeVarId++, displayName, node); + this.getPolyContext().typeVars.add(tv); return tv; } - public getContext(): InferContext { - return this.contexts[this.contexts.length-1]; - } - private addConstraint(constraint: Constraint): void { + switch (constraint.kind) { + case ConstraintKind.Empty: break; + case ConstraintKind.Many: for (const element of constraint.elements) { this.addConstraint(element); } break; + case ConstraintKind.Equal: { const global = 0; + let maxLevelLeft = global; - for (let i = this.contexts.length; i-- > 0;) { - const ctx = this.contexts[i]; + for (let i = this.polyContextStack.length; i-- > 0;) { + const ctx = this.polyContextStack[i]; if (hasTypeVar(ctx.typeVars, constraint.left)) { maxLevelLeft = i; break; } } + let maxLevelRight = global; - for (let i = this.contexts.length; i-- > 0;) { - const ctx = this.contexts[i]; + for (let i = this.polyContextStack.length; i-- > 0;) { + const ctx = this.polyContextStack[i]; if (hasTypeVar(ctx.typeVars, constraint.right)) { maxLevelRight = i; break; } } + const upperLevel = Math.max(maxLevelLeft, maxLevelRight); let lowerLevel = upperLevel; - for (let i = 0; i < this.contexts.length; i++) { - const ctx = this.contexts[i]; + for (let i = 0; i < this.polyContextStack.length; i++) { + const ctx = this.polyContextStack[i]; if (hasTypeVar(ctx.typeVars, constraint.left) || hasTypeVar(ctx.typeVars, constraint.right)) { lowerLevel = i; break; } } + if (upperLevel == lowerLevel || maxLevelLeft == global || maxLevelRight == global) { this.solve(constraint); } else { - this.contexts[upperLevel].constraints.push(constraint); + this.polyContextStack[upperLevel].constraints.push(constraint); } + break; } + } - } - private pushContext(context: InferContext) { - this.contexts.push(context); - } - - private popContext(context: InferContext) { - assert(this.contexts[this.contexts.length-1] === context); - this.contexts.pop(); } private generalize(type: Type, constraints: Constraint[], env: TypeEnv): Scheme { @@ -544,7 +594,8 @@ export class Checker { maxIndex = Math.max(maxIndex, i+1); currDown = nextDown; } - const found = currDown.kindEnv!.get(name.text); + const currDownInfo = this.getInfo(currDown); + const found = currDownInfo.kindEnv!.get(name.text); if (found !== null) { return found; } @@ -582,11 +633,16 @@ export class Checker { } private lookup(node: NodeWithReference, expectedKind: Symkind, enableDiagnostics = true): Scheme | null { + const [modulePath, name] = splitReferences(node); + if (modulePath.length > 0) { + let maxIndex = 0; let currUp = node.getEnclosingModule(); + outer: for (;;) { + let currDown = currUp; for (let i = 0; i < modulePath.length; i++) { const moduleName = modulePath[i]; @@ -609,10 +665,14 @@ export class Checker { maxIndex = Math.max(maxIndex, i+1); currDown = nextDown; } - const found = currDown.typeEnv!.get(name.text, expectedKind); + + const currDownInfo = this.getInfo(currDown); + + const found = currDownInfo.typeEnv!.get(name.text, expectedKind); if (found !== null) { return found; } + if (enableDiagnostics) { this.diagnostics.add( new BindingNotFoundDiagnostic( @@ -622,17 +682,20 @@ export class Checker { ) ); } + return null; } + } else { - let curr: TypeEnv | null = this.getContext().env; - do { + + for (let i = this.typeEnvStack.length-1; i >= 0; i--) { + const curr = this.typeEnvStack[i]; const found = curr.get(name.text, expectedKind); if (found !== null) { return found; } - curr = curr.parent; - } while(curr !== null); + } + if (enableDiagnostics) { this.diagnostics.add( new BindingNotFoundDiagnostic( @@ -642,25 +705,18 @@ export class Checker { ) ); } + return null; + } - } - private getReturnType(): Type { - const context = this.getContext(); - assert(context.returnType !== null); - return context.returnType; - } - - private getTypeEnv(): TypeEnv { - return this.getContext().env; } private createSubstitution(scheme: Scheme): TVSub { const sub = new TVSub(); const tvs = [...scheme.typeVars] for (const tv of tvs) { - sub.set(tv, this.createTypeVar()); + sub.set(tv, this.createTRegularVar()); } return sub; } @@ -739,7 +795,7 @@ export class Checker { } private addBinding(name: string, scheme: Scheme, kind: Symkind): void { - this.getContext().env.add(name, scheme, kind); + this.getTypeEnv().add(name, scheme, kind); } private unifyKindMany(first: Kind, rest: Kind[], node: TypeExpression): boolean { @@ -752,6 +808,9 @@ export class Checker { // any errors and wish to proceed with type inference on this node. let kind: Kind | undefined; + // Fetch the type checking information for this node because we're going to use it anyways. + const info = this.getInfo(node); + switch (node.kind) { case SyntaxKind.TupleTypeExpression: @@ -832,7 +891,7 @@ export class Checker { // and this way the kind can be refrieved very efficiently. // Note that at this point `kind` may be undefined. This signals further // inference logic that this node should be skipped because it already contains errors. - node.inferredKind = kind; + info.inferredKind = kind; // Set a filler default for the node in a way that allows other unification // errors to be caught. @@ -887,7 +946,8 @@ export class Checker { switch (node.kind) { case SyntaxKind.ModuleDeclaration: { - const innerEnv = node.kindEnv = new KindEnv(env); + const info = this.getInfo(node); + const innerEnv = info.kindEnv = new KindEnv(env); for (const element of node.elements) { this.forwardDeclareKind(element, innerEnv); } @@ -937,7 +997,8 @@ export class Checker { case SyntaxKind.ModuleDeclaration: { - const innerEnv = node.kindEnv!; + const info = this.getInfo(node); + const innerEnv = info.kindEnv!; for (const element of node.elements) { this.inferKind(element, innerEnv); } @@ -1161,14 +1222,16 @@ export class Checker { case SyntaxKind.LetDeclaration: { + const info = this.getInfo(node); + if (isFunctionDeclarationLike(node)) { node.activeCycle = true; node.visited = true; - const context = node.context!; - const returnType = context.returnType!; - this.pushContext(context); + this.pushInfo(info); + + const returnType = info.returnType!; if (node.body !== null) { switch (node.body.kind) { @@ -1193,22 +1256,21 @@ export class Checker { } } - this.popContext(context); + this.popInfo(info); + node.activeCycle = false; } else { - const ctx = this.getContext(); - const constraints = new ConstraintSet; - const innerCtx: InferContext = { - ...ctx, - constraints, - }; - this.pushContext(innerCtx); + // const constraints = new ConstraintSet; + // this.polyContextStack.push(new PolyContext(parentPoly.typeVars, constraints)); + let type; + if (node.typeAssert !== null) { type = this.inferTypeExpression(node.typeAssert.typeExpression); } + if (node.body !== null) { let bodyType; switch (node.body.kind) { @@ -1226,7 +1288,7 @@ export class Checker { if (type === undefined) { type = bodyType; } else { - constraints.push( + this.addConstraint( new CEqual( type, bodyType, @@ -1235,12 +1297,16 @@ export class Checker { ); } } + if (type === undefined) { - type = this.createTypeVar(); + type = this.createTRegularVar(); } - this.popContext(innerCtx); - this.inferBindings(node.pattern, type, undefined, constraints, true); + + // this.polyContextStack.pop(); + + this.inferBindings(node.pattern, type, undefined, undefined, true); } + break; } @@ -1264,6 +1330,9 @@ export class Checker { } } + // We're going to use this eventually so might as well fetch it now + const info = this.getInfo(node); + let type: Type; switch (node.kind) { @@ -1275,24 +1344,24 @@ export class Checker { case SyntaxKind.MatchExpression: { let exprType; + if (node.expression !== null) { exprType = this.inferExpression(node.expression); } else { - exprType = this.createTypeVar(); + exprType = this.createTRegularVar(); } - type = this.createTypeVar(); + + type = this.createTRegularVar(); + for (const arm of node.arms) { - const context = this.getContext(); - const newEnv = new TypeEnv(context.env); - const newContext: InferContext = { - constraints: context.constraints, - typeVars: context.typeVars, - env: newEnv, - returnType: context.returnType, - }; - this.pushContext(newContext); - const armPatternType = this.createTypeVar(); + + const newEnv = new TypeEnv(); + this.typeEnvStack.push(newEnv); + + const armPatternType = this.createTRegularVar(); + this.inferBindings(arm.pattern, armPatternType); + this.addConstraint( new CEqual( armPatternType, @@ -1300,6 +1369,7 @@ export class Checker { arm.pattern, ) ); + this.addConstraint( new CEqual( type, @@ -1307,11 +1377,14 @@ export class Checker { arm.expression ) ); - this.popContext(newContext); + + this.typeEnvStack.pop(); } + if (node.expression === null) { type = new TArrow(exprType, type); } + break; } @@ -1325,7 +1398,7 @@ export class Checker { const target = scope.lookup(node.name.text); if (target !== null && target.kind === SyntaxKind.LetDeclaration) { if (target.activeCycle) { - return target.inferredType!; + return this.getInfo(target).inferredType!; } if (!target.visited) { this.infer(target); @@ -1334,7 +1407,7 @@ export class Checker { const scheme = this.lookup(node, Symkind.Var); if (scheme === null) { //this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); - type = this.createTypeVar(); + type = this.createTRegularVar(); break; } type = this.instantiate(scheme, node); @@ -1357,8 +1430,8 @@ export class Checker { default: assertNever(name); } - const newFieldType = this.createTypeVar(name); - const newRestType = this.createTypeVar(); + const newFieldType = this.createTRegularVar(name); + const newRestType = this.createTRegularVar(); this.addConstraint( new CEqual( type, @@ -1374,7 +1447,7 @@ export class Checker { case SyntaxKind.CallExpression: { const opType = this.inferExpression(node.func); - type = this.createTypeVar(node); + type = this.createTRegularVar(node); const paramTypes = []; for (const arg of node.args) { paramTypes.push(this.inferExpression(arg)); @@ -1421,7 +1494,7 @@ export class Checker { let fieldType; if (scheme === null) { // this.diagnostics.add(new BindingNotFoudDiagnostic(member.name.text, member.name)); - fieldType = this.createTypeVar(); + fieldType = this.createTRegularVar(); } else { fieldType = this.instantiate(scheme, member); } @@ -1441,12 +1514,12 @@ export class Checker { const scheme = this.lookup(node.operator, Symkind.Var); if (scheme === null) { // this.diagnostics.add(new BindingNotFoudDiagnostic(node.operator.text, node.operator)); - return this.createTypeVar(); + return this.createTRegularVar(); } const opType = this.instantiate(scheme, node.operator); const leftType = this.inferExpression(node.left); const rightType = this.inferExpression(node.right); - type = this.createTypeVar(); + type = this.createTRegularVar(); this.addConstraint( new CEqual( new TArrow(leftType, new TArrow(rightType, type)), @@ -1462,7 +1535,7 @@ export class Checker { } - node.inferredType = type; + info.inferredType = type; return type; @@ -1472,9 +1545,11 @@ export class Checker { let type; - if (checkKind && !node.inferredKind) { + const info = this.getInfo(node); - type = this.createTypeVar(); + if (checkKind && info.inferredKind === undefined) { + + type = this.createTRegularVar(); } else { @@ -1485,7 +1560,7 @@ export class Checker { const scheme = this.lookup(node, Symkind.Type); if (scheme === null) { // this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name)); - type = this.createTypeVar(); + type = this.createTRegularVar(); break; } type = this.instantiate(scheme, node.name); @@ -1549,12 +1624,13 @@ export class Checker { case SyntaxKind.ForallTypeExpression: { - const ctx = this.getContext(); + const env = this.getTypeEnv(); + const poly = this.getPolyContext(); // FIXME this is an ugly hack that doesn't even work. Either disallow Forall in this method or create a new TForall for (const varExpr of node.varTypeExps) { - const tv = this.createTypeVar(); - ctx.env.add(varExpr.name.text, Forall.mono(tv), Symkind.Type); - ctx.typeVars.add(tv); + const tv = this.createTRegularVar(); + env.add(varExpr.name.text, Forall.mono(tv), Symkind.Type); + poly.typeVars.add(tv); } return this.inferTypeExpression(node.typeExpr, introduceTypeVars); } @@ -1577,7 +1653,7 @@ export class Checker { } - node.inferredType = type; + info.inferredType = type; return type; @@ -1601,7 +1677,7 @@ export class Checker { } case SyntaxKind.NestedPattern: - this.inferBindings(pattern.pattern, type, typeVars, constraints); + this.inferBindings(pattern.pattern, type, typeVars, constraints, generalize); break; case SyntaxKind.NamedTuplePattern: @@ -1613,7 +1689,7 @@ export class Checker { const ctorType = this.instantiate(scheme, pattern); let elementTypes = []; for (const element of pattern.elements) { - const tv = this.createTypeVar(); + const tv = this.createTRegularVar(); this.inferBindings(element, tv, typeVars, constraints, generalize); elementTypes.push(tv); } @@ -1646,8 +1722,8 @@ export class Checker { case SyntaxKind.DisjunctivePattern: { - this.inferBindings(pattern.left, type, typeVars, constraints), - this.inferBindings(pattern.right, type, typeVars, constraints); + this.inferBindings(pattern.left, type, typeVars, constraints, generalize), + this.inferBindings(pattern.right, type, typeVars, constraints, generalize); break; } @@ -1659,23 +1735,23 @@ export class Checker { if (variadicMember === null) { restType = new TNil(pattern); } else { - restType = this.createTypeVar(); + restType = this.createTRegularVar(); if (variadicMember.pattern !== null) { - this.inferBindings(variadicMember.pattern, restType, typeVars, constraints); + this.inferBindings(variadicMember.pattern, restType, typeVars, constraints, generalize); } } for (const member of pattern.members) { switch (member.kind) { case SyntaxKind.StructPatternField: { - const fieldType = this.createTypeVar(); - this.inferBindings(member.pattern, fieldType, typeVars, constraints); + const fieldType = this.createTRegularVar(); + this.inferBindings(member.pattern, fieldType, typeVars, constraints, generalize); fields.set(member.name.text, fieldType); break; } case SyntaxKind.PunnedStructPatternField: { - const fieldType = this.createTypeVar(); + const fieldType = this.createTRegularVar(); this.addBinding(member.name.text, Forall.mono(fieldType), Symkind.Var); fields.set(member.name.text, fieldType); break; @@ -1703,42 +1779,52 @@ export class Checker { } - private initialize(node: Syntax, parentEnv: TypeEnv): void { + private initialize(node: Syntax): void { switch (node.kind) { case SyntaxKind.SourceFile: + { + const info = this.getInfo(node); + const poly = info.poly = new PolyContext(); + const returnType = info.returnType = null; + const env = info.typeEnv = new TypeEnv(); + + this.polyContextStack.push(poly); + this.typeEnvStack.push(env); + this.returnTypeStack.push(returnType); + + for (const element of node.elements) { + this.initialize(element); + } + + this.polyContextStack.pop(); + this.typeEnvStack.pop(); + this.returnTypeStack.pop(); + + break; + } + case SyntaxKind.ModuleDeclaration: { - const env = node.typeEnv = new TypeEnv(parentEnv); + const info = this.getInfo(node); + info.typeEnv = new TypeEnv(); for (const element of node.elements) { - this.initialize(element, env); + this.initialize(element); } break; } case SyntaxKind.ClassDeclaration: { - const other = this.classDecls.get(node.name.text); - if (other !== undefined) { - this.diagnostics.add(new TypeclassDeclaredTwiceDiagnostic(node.name, other)); - } else { - if (node.constraintClause !== null) { - for (const constraint of node.constraintClause.constraints) { - if (!this.classDecls.has(constraint.name.text)) { - this.diagnostics.add(new TypeclassNotFoundDiagnostic(constraint.name.text, constraint.name)); - } - } - } - this.classDecls.set(node.name.text, node); - } - const env = node.typeEnv = new TypeEnv(parentEnv); + const info = this.getInfo(node); + const env = info.typeEnv = new TypeEnv(); for (const tv of node.types) { assert(tv.kind === SyntaxKind.VarTypeExpression); - env.add(tv.name.text, Forall.mono(this.createTypeVar(tv)), Symkind.Type); + env.add(tv.name.text, Forall.mono(this.createTRegularVar(tv)), Symkind.Type); } for (const element of node.elements) { - this.initialize(element, env); + this.initialize(element); } break; } @@ -1748,19 +1834,22 @@ export class Checker { if (!this.classDecls.has(node.name.text)) { this.diagnostics.add(new TypeclassNotFoundDiagnostic(node.name.text, node.name)); } - const env = node.typeEnv = new TypeEnv(parentEnv); + const info = this.getInfo(node); + info.typeEnv = new TypeEnv(); for (const element of node.elements) { - this.initialize(element, env); + this.initialize(element); } break; } case SyntaxKind.LetDeclaration: { - const env = node.typeEnv = new TypeEnv(parentEnv); + const info = this.getInfo(node); + info.typeEnv = new TypeEnv(); + // The rest of the info properties are set in Checker.check() if (node.body !== null && node.body.kind === SyntaxKind.BlockBody) { for (const element of node.body.elements) { - this.initialize(element, env); + this.initialize(element); } } break; @@ -1773,34 +1862,35 @@ export class Checker { case SyntaxKind.EnumDeclaration: { - const env = node.typeEnv = new TypeEnv(parentEnv); - const constraints = new ConstraintSet(); - const typeVars = new TVSet(); - const context: InferContext = { - typeVars, - env, - constraints, - returnType: null, - } + const info = this.getInfo(node); + const env = info.typeEnv = new TypeEnv();; + const poly = info.poly = new PolyContext(); + const parentEnv = this.getTypeEnv(); - this.pushContext(context); + this.typeEnvStack.push(env); + this.polyContextStack.push(poly); - const kindArgs = []; + const typeArgs = []; for (const name of node.varExps) { - const kindArg = this.createTypeVar(); - env.add(name.text, Forall.mono(kindArg), Symkind.Type); - kindArgs.push(kindArg); + const typeArg = this.createTRegularVar(); + env.add(name.text, Forall.mono(typeArg), Symkind.Type); + typeArgs.push(typeArg); } const type = this.createTCon(node.name.text, node); - const appliedType = TApp.build(type, kindArgs); - parentEnv.add(node.name.text, new Forall(typeVars, new CMany(constraints), type), Symkind.Type); + const appliedType = TApp.build(type, typeArgs); + parentEnv.add(node.name.text, new Forall(poly.typeVars, new CMany(poly.constraints), type), Symkind.Type); let elementTypes: Type[] = []; + if (node.members !== null) { + for (const member of node.members) { + let ctorType, elementType; + switch (member.kind) { + case SyntaxKind.EnumDeclarationTupleElement: { const args: Array<[Syntax, Type]> = member.elements.map(el => [el, this.inferTypeExpression(el, false)]); @@ -1808,6 +1898,7 @@ export class Checker { ctorType = TArrow.build(args.map(a => a[1]), appliedType, member); break; } + case SyntaxKind.EnumDeclarationStructElement: { const restType = new TNil(member); @@ -1819,74 +1910,89 @@ export class Checker { ctorType = new TArrow(elementType, appliedType, member); break; } + default: throw new Error(`Unexpected ${member}`); + } - parentEnv.add(member.name.text, new Forall(typeVars, new CMany(constraints), ctorType), Symkind.Var); + + parentEnv.add(member.name.text, new Forall(poly.typeVars, new CMany(poly.constraints), ctorType), Symkind.Var); elementTypes.push(elementType); } + } - this.popContext(context); + this.polyContextStack.pop(); + this.typeEnvStack.pop(); break; } case SyntaxKind.TypeDeclaration: { - const env = node.typeEnv = new TypeEnv(parentEnv); - const constraints = new ConstraintSet(); - const typeVars = new TVSet(); - const context: InferContext = { - constraints, - typeVars, - env, - returnType: null, - }; - this.pushContext(context); - const kindArgs = []; + const info = this.getInfo(node); + const parentEnv = this.getTypeEnv(); + const env = info.typeEnv = new TypeEnv();; + const poly = info.poly = new PolyContext(); + + this.polyContextStack.push(poly); + this.typeEnvStack.push(env); + + const typeArgs = []; + for (const varExpr of node.varExps) { - const typeVar = this.createTypeVar(); - kindArgs.push(typeVar); + const typeVar = this.createTRegularVar(); + typeArgs.push(typeVar); env.add(varExpr.text, Forall.mono(typeVar), Symkind.Type); } + const type = this.inferTypeExpression(node.typeExpression); - this.popContext(context); - const scheme = new Forall(typeVars, new CMany(constraints), TApp.build(type, kindArgs)); + + this.polyContextStack.pop(); + this.typeEnvStack.pop(); + + const scheme = new Forall(poly.typeVars, new CMany(poly.constraints), TApp.build(type, typeArgs)); + parentEnv.add(node.name.text, scheme, Symkind.Type); + break; } case SyntaxKind.StructDeclaration: { - const env = node.typeEnv = new TypeEnv(parentEnv); - const typeVars = new TVSet(); - const constraints = new ConstraintSet(); - const context: InferContext = { - constraints, - typeVars, - env, - returnType: null, - }; - this.pushContext(context); - const kindArgs = []; + const info = this.getInfo(node); + const parentEnv = this.getTypeEnv(); + const env = info.typeEnv = new TypeEnv(); + const poly = info.poly = new PolyContext(); + + this.polyContextStack.push(poly); + this.typeEnvStack.push(env); + + const typeArgs = []; for (const varExpr of node.varExps) { - const kindArg = this.createTypeVar(); - env.add(varExpr.text, Forall.mono(kindArg), Symkind.Type); - kindArgs.push(kindArg); + const typeArg = this.createTRegularVar(); + env.add(varExpr.text, Forall.mono(typeArg), Symkind.Type); + typeArgs.push(typeArg); } + const fields = new Map(); const restType = new TNil(node); + if (node.fields !== null) { for (const field of node.fields) { fields.set(field.name.text, this.inferTypeExpression(field.typeExpr)); } } + const type = this.createTCon(node.name.text, node.name); const recordType = TField.build(fields, restType); - this.popContext(context); - parentEnv.add(node.name.text, new Forall(typeVars, new CMany(constraints), type), Symkind.Type); - parentEnv.add(node.name.text, new Forall(typeVars, new CMany(constraints), new TArrow(recordType, type)), Symkind.Var); + + this.polyContextStack.pop(); + this.typeEnvStack.pop(); + + parentEnv.add(node.name.text, new Forall(poly.typeVars, new CMany(poly.constraints), type), Symkind.Type); + parentEnv.add(node.name.text, new Forall(poly.typeVars, new CMany(poly.constraints), new TArrow(recordType, type)), Symkind.Var); + break; } @@ -1899,29 +2005,27 @@ export class Checker { public check(sourceFile: SourceFile): void { - const kenv = new KindEnv(this.globalKindEnv); - this.forwardDeclareKind(sourceFile, kenv); - this.inferKind(sourceFile, kenv); + // Kind inference + const kindEnv = new KindEnv(this.globalKindEnv); + this.forwardDeclareKind(sourceFile, kindEnv); + this.inferKind(sourceFile, kindEnv); - this.initialize(sourceFile, this.globalTypeEnv); + // Type inference - const typeVars = new TVSet(); - const constraints = new ConstraintSet(); - const sourceFileCtx = { - typeVars, - constraints, - env: sourceFile.typeEnv!, - returnType: null - }; + this.typeEnvStack.push(this.globalTypeEnv); - this.pushContext(sourceFileCtx); + this.initialize(sourceFile); + + const sourceFileInfo = this.getInfo(sourceFile); + this.pushInfo(sourceFileInfo); const sccs = [...this.analyser.getSortedDeclarations()]; for (const nodes of sccs) { - const typeVars = new TVSet(); - const constraints = new ConstraintSet(); + const poly = new PolyContext(); + + this.polyContextStack.push(poly); for (const node of nodes) { @@ -1929,22 +2033,15 @@ export class Checker { continue; } - const env = node.typeEnv!; - const innerCtx: InferContext = { - typeVars, - constraints, - env, - returnType: null, - }; - node.context = innerCtx; + const info = this.getInfo(node); + info.poly = poly; + const returnType = info.returnType = this.createTRegularVar(); - this.pushContext(innerCtx); - - const returnType = this.createTypeVar(); - innerCtx.returnType = returnType; + this.typeEnvStack.push(info.typeEnv!); + this.returnTypeStack.push(info.returnType!); const paramTypes = node.params.map(param => { - const paramType = this.createTypeVar(); + const paramType = this.createTRegularVar(); this.inferBindings(param.pattern, paramType) return paramType; }); @@ -1960,7 +2057,8 @@ export class Checker { ) ); } - node.inferredType = type; + + info.inferredType = type; // if (node.parent!.kind === SyntaxKind.InstanceDeclaration) { // const inst = node.parent!; @@ -1974,29 +2072,26 @@ export class Checker { // this.addConstraint(new CEqual(type, other.inferredType!, node)); // } - this.popContext(innerCtx); + this.returnTypeStack.pop(); + this.typeEnvStack.pop(); if (node.parent!.kind !== SyntaxKind.InstanceDeclaration) { - const scopeDecl = node.parent!.getScope().node; - const outer = { - typeVars: innerCtx.typeVars, - constraints: innerCtx.constraints, - env: scopeDecl.typeEnv!, - returnType: null, - }; - this.contexts.push(outer) - this.inferBindings(node.pattern, type, typeVars, constraints); - this.contexts.pop(); + this.inferBindings(node.pattern, type, poly.typeVars, poly.constraints); } + } + this.polyContextStack.pop(); + } this.infer(sourceFile); - this.popContext(sourceFileCtx); + // Pop off whatever we pushed in during initialization + this.popInfo(sourceFileInfo); + this.typeEnvStack.pop(); - this.solve(new CMany(constraints)); + this.solve(new CMany(sourceFileInfo.poly!.constraints)); } @@ -2005,8 +2100,8 @@ export class Checker { private maxTypeErrorCount = 5; private find(type: Type): Type { - while (type.kind === TypeKind.UniVar && this.solution.has(type)) { - type = this.solution.get(type)!; + while (type.kind === TypeKind.UniVar && this.typeSolution.has(type)) { + type = this.typeSolution.get(type)!; } return type; } @@ -2041,7 +2136,8 @@ export class Checker { private unify(left: Type, right: Type, enableDiagnostics: boolean): boolean { - // console.log(`unify ${describeType(left)} @ ${left.node && left.node.constructor && left.node.constructor.name} ~ ${describeType(right)} @ ${right.node && right.node.constructor && right.node.constructor.name}`); + //console.log(`unify ${describeType(left)} @ ${left.node && left.node.constructor && left.node.constructor.name} ~ ${describeType(right)} @ ${right.node && right.node.constructor && right.node.constructor.name}`); + //console.log(`unify ${describeType(left)} ~ ${describeType(right)}`); left = this.simplifyType(left); right = this.simplifyType(right); @@ -2111,8 +2207,8 @@ export class Checker { // Should it get assigned another unification variable, that's OK too // because then that variable is what matters and it will become the new // (possibly polymorphic) variable. - if (this.contexts.length > 0) { - this.contexts[this.contexts.length-1].typeVars.delete(left); + if (this.polyContextStack.length > 0) { + this.polyContextStack[this.polyContextStack.length-1].typeVars.delete(left); } // These types will be join, and we'd like to track that @@ -2165,7 +2261,7 @@ export class Checker { return success; } let success = true; - const newRestType = new TUniVar(this.typeVarIds.next().value!); + const newRestType = new TRegularVar(this.nextTypeVarId++); if (!this.unify(left.restType, new TField(right.name, right.type, newRestType), enableDiagnostics)) { success = false; } @@ -2253,8 +2349,9 @@ export class Checker { } public getTypeOfNode(node: Syntax): Type { - assert(node.inferredType !== undefined); - return this.simplifyType(node.inferredType); + const info = this.getInfo(node); + assert(info.inferredType !== undefined); + return this.simplifyType(info.inferredType); } // private *findInstanceContext(type: TCon, clazz: ClassDeclaration): Iterable { diff --git a/compiler/src/cst.ts b/compiler/src/cst.ts index 34e8a73c8..a67fc8ca5 100644 --- a/compiler/src/cst.ts +++ b/compiler/src/cst.ts @@ -4,7 +4,7 @@ import path from "path" import { assert, implementationLimitation, IndentWriter, JSONObject, JSONValue, nonenumerable, unreachable } from "./util"; import { isNodeWithScope, Scope } from "./scope" -import type { InferContext, Kind, KindEnv, Scheme, TypeEnv } from "./checker" +import type { Kind, Scheme } from "./checker" import type { Type } from "./types"; import { Emitter } from "./emitter"; @@ -253,11 +253,6 @@ abstract class SyntaxBase { @nonenumerable public parent: Syntax | null = null; - @nonenumerable - public inferredKind?: Kind; - @nonenumerable - public inferredType?: Type; - public abstract getFirstToken(): Token; public abstract getLastToken(): Token; @@ -296,6 +291,10 @@ abstract class SyntaxBase { throw new Error(`Could not find a scope for ${this}. Maybe the parent links are not set?`); } + public getParentScope(): Scope | null { + return this.parent === null ? null : this.parent.getScope(); + } + public getEnclosingModule(): ModuleDeclaration | SourceFile { let curr = this.parent!; while (curr !== null) { @@ -2595,9 +2594,6 @@ export class EnumDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.EnumDeclaration; - @nonenumerable - public typeEnv?: TypeEnv; - public constructor( public pubKeyword: PubKeyword | null, public enumKeyword: EnumKeyword, @@ -2668,9 +2664,6 @@ export class StructDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.StructDeclaration; - @nonenumerable - public typeEnv?: TypeEnv; - public constructor( public pubKeyword: PubKeyword | null, public structKeyword: StructKeyword, @@ -2836,9 +2829,6 @@ export class TypeDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.TypeDeclaration; - @nonenumerable - public typeEnv?: TypeEnv; - public constructor( public pubKeyword: PubKeyword | null, public typeKeyword: TypeKeyword, @@ -2881,17 +2871,11 @@ export class LetDeclaration extends SyntaxBase { @nonenumerable public scope?: Scope; - @nonenumerable - public typeEnv?: TypeEnv; - @nonenumerable public activeCycle?: boolean; @nonenumerable public visited?: boolean; - @nonenumerable - public context?: InferContext; - public constructor( public pubKeyword: PubKeyword | null, public letKeyword: LetKeyword, @@ -3082,9 +3066,6 @@ export class ClassDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.ClassDeclaration; - @nonenumerable - public typeEnv?: TypeEnv; - public constructor( public pubKeyword: PubKeyword | null, public classKeyword: ClassKeyword, @@ -3186,9 +3167,6 @@ export class InstanceDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.InstanceDeclaration; - @nonenumerable - public typeEnv?: TypeEnv; - public constructor( public pubKeyword: PubKeyword | null, public classKeyword: InstanceKeyword, @@ -3233,11 +3211,6 @@ export class ModuleDeclaration extends SyntaxBase { public readonly kind = SyntaxKind.ModuleDeclaration; - @nonenumerable - public typeEnv?: TypeEnv; - @nonenumerable - public kindEnv?: KindEnv; - public constructor( public pubKeyword: PubKeyword | null, public modKeyword: ModKeyword, @@ -3287,10 +3260,6 @@ export class SourceFile extends SyntaxBase { @nonenumerable public scope?: Scope; - @nonenumerable - public typeEnv?: TypeEnv; - @nonenumerable - public kindEnv?: KindEnv; public constructor( private file: TextFile, diff --git a/compiler/src/types.ts b/compiler/src/types.ts index 3305abeb7..c96dd0563 100644 --- a/compiler/src/types.ts +++ b/compiler/src/types.ts @@ -50,7 +50,7 @@ export abstract class TypeBase { this.find().parent = newType; } - public hasTypeVar(tv: TUniVar): boolean { + public hasTypeVar(tv: TRegularVar): boolean { for (const other of this.getTypeVars()) { if (tv.id === other.id) { return true; @@ -105,7 +105,7 @@ export class TRigidVar extends TypeBase { } -export class TUniVar extends TypeBase { +export class TRegularVar extends TypeBase { public readonly kind = TypeKind.UniVar; @@ -122,8 +122,8 @@ export class TUniVar extends TypeBase { yield this; } - public shallowClone(): TUniVar { - return new TUniVar(this.id, this.node); + public shallowClone(): TRegularVar { + return new TRegularVar(this.id, this.node); } public substitute(sub: TVSub): Type { @@ -451,7 +451,7 @@ export type Type = TCon | TArrow | TRigidVar - | TUniVar + | TRegularVar | TApp | TField | TNil @@ -459,10 +459,9 @@ export type Type | TAbsent export type TVar - = TUniVar + = TRegularVar | TRigidVar - export function typesEqual(a: Type, b: Type): boolean { if (a.kind !== b.kind) { return false;