Multiple updates to the type-checker

- Add support for type declarations
- Make polymorphism in struct declarations work
This commit is contained in:
Sam Vervaeck 2022-09-11 11:20:21 +02:00
parent d12ffa1de5
commit 988215cdb3
4 changed files with 353 additions and 41 deletions

View file

@ -4,13 +4,21 @@ import {
Pattern,
Scope,
SourceFile,
SourceFileElement,
StructDeclaration,
Symkind,
Syntax,
SyntaxKind,
TypeExpression
} from "./cst";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, FieldDoesNotExistDiagnostic, FieldMissingDiagnostic, UnificationFailedDiagnostic } from "./diagnostics";
import {
describeType,
ArityMismatchDiagnostic,
BindingNotFoudDiagnostic,
Diagnostics,
FieldDoesNotExistDiagnostic,
FieldMissingDiagnostic,
UnificationFailedDiagnostic,
KindMismatchDiagnostic
} from "./diagnostics";
import { assert, isEmpty } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
@ -30,6 +38,7 @@ export enum TypeKind {
Tuple,
Labeled,
Record,
App,
}
abstract class TypeBase {
@ -272,7 +281,7 @@ export class TRecord extends TypeBase {
public nextRecord: TRecord | null = null;
public constructor(
public decl: StructDeclaration,
public decl: Syntax,
public fields: Map<string, Type>,
public node: Syntax | null = null,
) {
@ -308,6 +317,68 @@ export class TRecord extends TypeBase {
}
/**
 * The application of a type operator to one argument type.
 *
 * Applications with several arguments are represented as a chain of `TApp`
 * nodes; use {@link TApp.build} to construct such a chain and
 * {@link TApp.getSequence} to read it back as a flat list.
 */
export class TApp extends TypeBase {

  public readonly kind = TypeKind.App;

  /**
   * @param operatorType The type being applied.
   * @param argType The type it is applied to.
   * @param node The syntax node this type is associated with, if any.
   */
  public constructor(
    public operatorType: Type,
    public argType: Type,
    public node: Syntax | null = null
  ) {
    super(node);
  }

  /**
   * Applies `operatorType` to every type in `argTypes`.
   *
   * The arguments are chained to the right, so `build(f, [a, b, c])` produces
   * `TApp(f, TApp(a, TApp(b, c)))`.
   *
   * @param operatorType The type being applied.
   * @param argTypes The arguments, in source order. May be empty.
   * @param node The syntax node attached to every `TApp` that is created.
   * @returns The application, or `operatorType` itself when `argTypes` is
   *          empty. Previously an empty list produced a `TApp` whose
   *          `argType` was `undefined`.
   */
  public static build(operatorType: Type, argTypes: Type[], node: Syntax | null = null): Type {
    const count = argTypes.length;
    if (count === 0) {
      // Nothing to apply: wrapping would require an argument that does not exist.
      return operatorType;
    }
    let result = argTypes[count - 1];
    for (let i = count - 2; i >= 0; i--) {
      result = new TApp(argTypes[i], result, node);
    }
    return new TApp(operatorType, result, node);
  }

  /**
   * Yields the non-application leaves of this type from left to right,
   * descending into both the operator and the argument.
   *
   * NOTE(review): because the argument side is flattened as well, an argument
   * that is itself an application is indistinguishable from extra arguments,
   * e.g. `build(f, [build(g, [x])])` and `build(f, [g, x])` yield the same
   * sequence.
   */
  public *getSequence(): Iterable<Type> {
    if (this.operatorType.kind === TypeKind.App) {
      yield* this.operatorType.getSequence();
    } else {
      yield this.operatorType;
    }
    if (this.argType.kind === TypeKind.App) {
      yield* this.argType.getSequence();
    } else {
      yield this.argType;
    }
  }

  /** Yields the type variables of the operator followed by those of the argument. */
  public *getTypeVars(): Iterable<TVar> {
    yield* this.operatorType.getTypeVars();
    yield* this.argType.getTypeVars();
  }

  /** Creates a new `TApp` that shares this node's operator, argument and syntax node. */
  public shallowClone() {
    return new TApp(
      this.operatorType,
      this.argType,
      this.node
    );
  }

  /**
   * Applies `sub` to the operator and the argument.
   *
   * @returns `this` when both substituted children are identical (`===`) to
   *          the current ones, otherwise a new `TApp` holding the results.
   */
  public substitute(sub: TVSub): Type {
    const newOperatorType = this.operatorType.substitute(sub);
    const newArgType = this.argType.substitute(sub);
    const changed = newOperatorType !== this.operatorType
      || newArgType !== this.argType;
    return changed ? new TApp(newOperatorType, newArgType, this.node) : this;
  }

}
export type Type
= TCon
| TArrow
@ -315,6 +386,7 @@ export type Type
| TTuple
| TLabeled
| TRecord
| TApp
class TVSet {
@ -597,11 +669,15 @@ export class Checker {
return context.returnType;
}
private instantiate(scheme: Scheme, node: Syntax | null): Type {
/**
 * Builds a substitution that maps each type variable listed in
 * `scheme.typeVars` to a type variable obtained from `createTypeVar`.
 *
 * @param scheme The scheme whose type variables are to be replaced.
 * @returns The populated substitution.
 */
private createSubstitution(scheme: Scheme): TVSub {
  const mapping = new TVSub();
  for (const quantified of scheme.typeVars) {
    const replacement = this.createTypeVar();
    mapping.set(quantified, replacement);
  }
  return mapping;
}
private instantiate(scheme: Scheme, node: Syntax | null, sub = this.createSubstitution(scheme)): Type {
for (const constraint of scheme.constraints) {
const substituted = constraint.substitute(sub);
substituted.node = node;
@ -716,6 +792,8 @@ export class Checker {
break;
}
case SyntaxKind.TypeDeclaration:
case SyntaxKind.EnumDeclaration:
case SyntaxKind.StructDeclaration:
break;
@ -820,13 +898,20 @@ export class Checker {
case SyntaxKind.StructExpression:
{
const scheme = this.lookup(node.name.text);
if (scheme === null) {
const scope = node.getScope();
const decl = scope.lookup(node.name.text, Symkind.Constructor);
if (decl === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
}
const recordType = this.instantiate(scheme, node);
assert(recordType.kind === TypeKind.Record);
assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement);
const scheme = decl.scheme;
const sub = this.createSubstitution(scheme);
const declType = this.instantiate(scheme, node, sub);
const argTypes = [];
for (const typeVar of decl.tvs) {
argTypes.push(sub.get(typeVar)!);
}
const fields = new Map();
for (const member of node.members) {
switch (member.kind) {
@ -852,10 +937,10 @@ export class Checker {
throw new Error(`Unexpected ${member}`);
}
}
const type = new TRecord(recordType.decl, fields, node);
const type = TApp.build(new TRecord(decl, fields, node), argTypes, node);
this.addConstraint(
new CEqual(
recordType,
TApp.build(declType, argTypes, node),
type,
node,
)
@ -923,6 +1008,16 @@ export class Checker {
return scheme.type;
}
case SyntaxKind.AppTypeExpression:
{
const operatorType = this.inferTypeExpression(node.operator);
const argTypes = [];
for (const argTypeExpr of node.args) {
argTypes.push(this.inferTypeExpression(argTypeExpr));
}
return TApp.build(operatorType, argTypes);
}
case SyntaxKind.ArrowTypeExpression:
{
const paramTypes = [];
@ -1145,6 +1240,7 @@ export class Checker {
break;
}
case SyntaxKind.TypeDeclaration:
case SyntaxKind.EnumDeclaration:
case SyntaxKind.StructDeclaration:
break;
@ -1186,6 +1282,7 @@ export class Checker {
case SyntaxKind.IfStatement:
case SyntaxKind.ReturnStatement:
case SyntaxKind.ExpressionStatement:
case SyntaxKind.TypeDeclaration:
case SyntaxKind.EnumDeclaration:
case SyntaxKind.StructDeclaration:
break;
@ -1232,6 +1329,29 @@ export class Checker {
break;
}
case SyntaxKind.TypeDeclaration:
{
const env = node.typeEnv = new TypeEnv(parentEnv);
const constraints = new ConstraintSet();
const typeVars = new TVSet();
const context: InferContext = {
constraints,
typeVars,
env,
returnType: null,
};
this.pushContext(context);
for (const varExpr of node.typeVars) {
env.add(varExpr.text, new Forall([], [], this.createTypeVar()));
}
const type = this.inferTypeExpression(node.typeExpression);
this.popContext(context);
const scheme = new Forall(typeVars, constraints, type);
parentEnv.add(node.name.text, scheme);
node.scheme = scheme;
break;
}
case SyntaxKind.StructDeclaration:
{
const env = node.typeEnv = new TypeEnv(parentEnv);
@ -1244,8 +1364,11 @@ export class Checker {
returnType: null,
};
this.pushContext(context);
const argTypes = [];
for (const varExpr of node.typeVars) {
env.add(varExpr.text, new Forall([], [], this.createTypeVar()));
const type = this.createTypeVar();
env.add(varExpr.text, new Forall([], [], type));
argTypes.push(type);
}
const fields = new Map<string, Type>();
if (node.members !== null) {
@ -1254,8 +1377,11 @@ export class Checker {
}
}
this.popContext(context);
const type = new TRecord(node, fields);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type));
const type = new TRecord(node, fields, node);
const scheme = new Forall(typeVars, constraints, type);
parentEnv.add(node.name.text, scheme);
node.tvs = argTypes;
node.scheme = scheme; //new Forall(typeVars, constraints, new TApp(type, argTypes));
break;
}
@ -1561,6 +1687,23 @@ export class Checker {
}
}
if (left.kind === TypeKind.App && right.kind === TypeKind.App) {
let leftElements = [...left.getSequence()];
let rightElements = [...right.getSequence()];
if (leftElements.length !== rightElements.length) {
this.diagnostics.add(new KindMismatchDiagnostic(leftElements.length-1, rightElements.length-1, constraint.node));
return false;
}
const count = leftElements.length;
let success = true;
for (let i = 0; i < count; i++) {
if (!this.unify(leftElements[i], rightElements[i], solution, constraint)) {
success = false;
}
}
return success;
}
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
let success = false;
// This works like an ordinary union-find algorithm where an additional
@ -1604,12 +1747,12 @@ export class Checker {
}
remaining.delete(fieldName);
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName));
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName, constraint.node));
success = false;
}
}
for (const fieldName of remaining) {
this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName));
this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName, constraint.node));
}
if (success) {
TypeBase.join(left, right);

View file

@ -108,6 +108,7 @@ export const enum SyntaxKind {
ReferenceTypeExpression,
ArrowTypeExpression,
VarTypeExpression,
AppTypeExpression,
// Patterns
BindPattern,
@ -147,14 +148,11 @@ export const enum SyntaxKind {
IfStatementCase,
// Declarations
VariableDeclaration,
PrefixFuncDecl,
SuffixFuncDecl,
LetDeclaration,
StructDeclaration,
EnumDeclaration,
ImportDeclaration,
TypeAliasDeclaration,
TypeDeclaration,
// Let declaration body members
ExprBody,
@ -185,6 +183,7 @@ export type Syntax
| Param
| Body
| StructDeclarationField
| EnumDeclarationElement
| TypeAssert
| Declaration
| Statement
@ -207,9 +206,16 @@ function isNodeWithScope(node: Syntax): node is NodeWithScope {
|| node.kind === SyntaxKind.LetDeclaration;
}
/**
 * Bit flags that classify the symbols stored in a scope.
 *
 * Each member occupies its own bit so that several kinds can be combined
 * with `|` and tested with `&`.
 */
export const enum Symkind {
  Var = 1 << 0,
  Type = 1 << 1,
  Constructor = 1 << 2,
  /** Matches every kind of symbol. */
  Any = Var | Type | Constructor,
}
export class Scope {
private mapping = new Map<string, Syntax>();
private mapping = new Map<string, [Symkind, Syntax]>();
public constructor(
public node: NodeWithScope,
@ -228,6 +234,10 @@ export class Scope {
return null;
}
/**
 * Registers `node` under `name` together with its symbol kind.
 * An entry already stored under the same name is replaced.
 */
private add(name: string, node: Syntax, kind: Symkind): void {
  const entry: [Symkind, Syntax] = [kind, node];
  this.mapping.set(name, entry);
}
private scan(node: Syntax): void {
switch (node.kind) {
case SyntaxKind.SourceFile:
@ -241,9 +251,17 @@ export class Scope {
case SyntaxKind.ReturnStatement:
case SyntaxKind.IfStatement:
break;
case SyntaxKind.TypeDeclaration:
{
this.add(node.name.text, node, Symkind.Type);
break;
}
case SyntaxKind.EnumDeclaration:
case SyntaxKind.StructDeclaration:
{
this.add(node.name.text, node, Symkind.Constructor);
break;
}
case SyntaxKind.LetDeclaration:
{
for (const param of node.params) {
@ -257,7 +275,7 @@ export class Scope {
}
} else {
if (node.pattern.kind === SyntaxKind.WrappedOperator) {
this.mapping.set(node.pattern.operator.text, node);
this.add(node.pattern.operator.text, node, Symkind.Var);
} else {
this.scanPattern(node.pattern, node);
}
@ -273,7 +291,7 @@ export class Scope {
switch (node.kind) {
case SyntaxKind.BindPattern:
{
this.mapping.set(node.name.text, decl);
this.add(node.name.text, decl, Symkind.Var);
break;
}
case SyntaxKind.StructPattern:
@ -287,7 +305,7 @@ export class Scope {
}
case SyntaxKind.PunnedStructPatternField:
{
this.mapping.set(node.name.text, decl);
this.add(node.name.text, decl, Symkind.Var);
break;
}
}
@ -299,13 +317,16 @@ export class Scope {
}
}
public lookup(name: string): Syntax | null {
public lookup(name: string, expectedKind = Symkind.Any): Syntax | null {
let curr: Scope | null = this;
do {
const decl = curr.mapping.get(name);
if (decl !== undefined) {
const match = curr.mapping.get(name);
if (match !== undefined) {
const [kind, decl] = match;
if (kind & expectedKind) {
return decl;
}
}
curr = curr.getParent();
} while (curr !== null);
return null;
@ -967,6 +988,30 @@ export class ReferenceTypeExpression extends SyntaxBase {
}
/**
 * A type expression in which an operator is followed by a list of
 * argument type expressions.
 */
export class AppTypeExpression extends SyntaxBase {

  public readonly kind = SyntaxKind.AppTypeExpression;

  public constructor(
    public operator: TypeExpression,
    public args: TypeExpression[],
  ) {
    super();
  }

  /** The first token of the operator. */
  public getFirstToken(): Token {
    return this.operator.getFirstToken();
  }

  /** The last token of the final argument, or of the operator when there are no arguments. */
  public getLastToken(): Token {
    const argCount = this.args.length;
    const trailing = argCount > 0 ? this.args[argCount - 1] : this.operator;
    return trailing.getLastToken();
  }

}
export class VarTypeExpression extends SyntaxBase {
public readonly kind = SyntaxKind.VarTypeExpression;
@ -991,6 +1036,7 @@ export type TypeExpression
= ReferenceTypeExpression
| ArrowTypeExpression
| VarTypeExpression
| AppTypeExpression
export class BindPattern extends SyntaxBase {
@ -1853,6 +1899,34 @@ export class WrappedOperator extends SyntaxBase {
}
/**
 * Syntax node for a type declaration: an optional `pub` keyword, the type
 * keyword, a name, zero or more type variables, an equals sign and the
 * type expression on the right-hand side.
 */
export class TypeDeclaration extends SyntaxBase {

  public readonly kind = SyntaxKind.TypeDeclaration;

  public constructor(
    public pubKeyword: PubKeyword | null,
    public typeKeyword: TypeKeyword,
    public name: IdentifierAlt,
    public typeVars: Identifier[],
    public equals: Equals,
    public typeExpression: TypeExpression
  ) {
    super();
  }

  /** The `pub` keyword when one is present, otherwise the type keyword. */
  public getFirstToken(): Token {
    return this.pubKeyword === null
      ? this.typeKeyword
      : this.pubKeyword;
  }

  /** The last token of the right-hand side type expression. */
  public getLastToken(): Token {
    const rhs = this.typeExpression;
    return rhs.getLastToken();
  }

}
export class LetDeclaration extends SyntaxBase {
public readonly kind = SyntaxKind.LetDeclaration;
@ -1923,6 +1997,7 @@ export type Declaration
| ImportDeclaration
| StructDeclaration
| EnumDeclaration
| TypeDeclaration
export class Initializer extends SyntaxBase {

View file

@ -198,19 +198,25 @@ export function describeType(type: Type): string {
}
case TypeKind.Record:
{
let out = type.decl.name.text + ' { ';
let first = true;
for (const [fieldName, fieldType] of type.fields) {
if (first) first = false;
else out += ', ';
out += fieldName + ': ' + describeType(fieldType);
}
return out + ' }';
return type.decl.name.text;
// let out = type.decl.name.text + ' { ';
// let first = true;
// for (const [fieldName, fieldType] of type.fields) {
// if (first) first = false;
// else out += ', ';
// out += fieldName + ': ' + describeType(fieldType);
// }
// return out + ' }';
}
case TypeKind.Labeled:
{
// FIXME may need to include fields that were added during unification
return '{ ' + type.name + ': ' + describeType(type.type) + ' }';
}
case TypeKind.App:
{
return describeType(type.operatorType) + ' ' + describeType(type.argType);
}
}
}
@ -294,6 +300,7 @@ export class FieldMissingDiagnostic {
public constructor(
public recordType: TRecord,
public fieldName: string,
public node: Syntax | null,
) {
}
@ -302,8 +309,8 @@ export class FieldMissingDiagnostic {
out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET);
out.write(`field '${this.fieldName}' is missing from `);
out.write(describeType(this.recordType) + '\n\n');
if (this.recordType.node !== null) {
out.write(printNode(this.recordType.node) + '\n');
if (this.node !== null) {
out.write(printNode(this.node) + '\n');
}
}
@ -316,13 +323,48 @@ export class FieldDoesNotExistDiagnostic {
public constructor(
public recordType: TRecord,
public fieldName: string,
public node: Syntax | null,
) {
}
public format(out: IndentWriter): void {
out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET);
out.write(`field '${this.fieldName}' does not exist on type `);
out.write(describeType(this.recordType) + '\n\n');
if (this.node !== null) {
out.write(printNode(this.node) + '\n');
}
}
}
/**
 * Error reported when two kinds of different sizes are compared.
 */
export class KindMismatchDiagnostic {

  public readonly level = Level.Error;

  /**
   * @param leftSize Number of `*` printed for the left-hand kind.
   * @param rightSize Number of `*` printed for the right-hand kind.
   * @param node Syntax node printed below the message, if any.
   */
  public constructor(
    public leftSize: number,
    public rightSize: number,
    public node: Syntax | null,
  ) {

  }

  /** Writes the error header, both kinds, and the offending node (when present) to `out`. */
  public format(out: IndentWriter): void {
    // Emits one `* -> ` for every star except the last one.
    const writeArrows = (size: number): void => {
      for (let remaining = size - 1; remaining > 0; remaining--) {
        out.write(`* -> `);
      }
    };
    out.write(ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET);
    out.write(`kind `);
    writeArrows(this.leftSize);
    out.write(`* does not match with `);
    writeArrows(this.rightSize);
    out.write(`*\n\n`);
    if (this.node === null) {
      return;
    }
    out.write(printNode(this.node) + '\n');
  }

}
@ -335,6 +377,7 @@ export type Diagnostic
| ArityMismatchDiagnostic
| FieldMissingDiagnostic
| FieldDoesNotExistDiagnostic
| KindMismatchDiagnostic
export interface Diagnostics {
add(diagnostic: Diagnostic): void;

View file

@ -52,6 +52,8 @@ import {
EnumDeclaration,
EnumDeclarationTupleElement,
VarTypeExpression,
TypeDeclaration,
AppTypeExpression,
} from "./cst"
import { Stream } from "./util";
@ -181,8 +183,30 @@ export class Parser {
}
}
/**
 * Parses a primitive type expression followed by as many further primitive
 * type expressions as are present, stopping at the first token that closes
 * a bracket, starts a block, ends the line fold, or is `=` or an arrow.
 *
 * @returns The lone primitive expression when nothing follows it, otherwise
 *          an `AppTypeExpression` applying the first expression to the rest.
 */
private tryParseAppTypeExpression(): TypeExpression {
  const head = this.parsePrimitiveTypeExpression();
  const operands = [];
  collect: for (;;) {
    switch (this.peekToken().kind) {
      case SyntaxKind.RParen:
      case SyntaxKind.RBrace:
      case SyntaxKind.RBracket:
      case SyntaxKind.Equals:
      case SyntaxKind.BlockStart:
      case SyntaxKind.LineFoldEnd:
      case SyntaxKind.RArrow:
        break collect;
      default:
        operands.push(this.parsePrimitiveTypeExpression());
    }
  }
  return operands.length === 0
    ? head
    : new AppTypeExpression(head, operands);
}
public parseTypeExpression(): TypeExpression {
let returnType = this.parsePrimitiveTypeExpression();
let returnType = this.tryParseAppTypeExpression();
const paramTypes = [];
for (;;) {
const t1 = this.peekToken();
@ -191,7 +215,7 @@ export class Parser {
}
this.getToken();
paramTypes.push(returnType);
returnType = this.parsePrimitiveTypeExpression();
returnType = this.tryParseAppTypeExpression();
}
if (paramTypes.length === 0) {
return returnType;
@ -417,6 +441,31 @@ export class Parser {
return this.parseBinaryOperatorAfterExpr(lhs, 0);
}
/**
 * Parses a type declaration: an optional `pub` keyword, the type keyword,
 * the declared name, any number of identifier type variables, an equals
 * sign, a type expression and the end of the line fold.
 *
 * A token other than the expected keyword or equals sign is reported
 * through `raiseParseError`.
 */
public parseTypeDeclaration(): TypeDeclaration {
  let visibility = null;
  let keyword = this.getToken();
  if (keyword.kind === SyntaxKind.PubKeyword) {
    visibility = keyword;
    keyword = this.getToken();
  }
  if (keyword.kind !== SyntaxKind.TypeKeyword) {
    this.raiseParseError(keyword, [ SyntaxKind.TypeKeyword ]);
  }
  const name = this.expectToken(SyntaxKind.IdentifierAlt);
  const typeVars = [];
  let next = this.getToken();
  // Consume identifiers until the first token that is not one.
  for (; next.kind === SyntaxKind.Identifier; next = this.getToken()) {
    typeVars.push(next);
  }
  if (next.kind !== SyntaxKind.Equals) {
    this.raiseParseError(next, [ SyntaxKind.Equals ]);
  }
  const body = this.parseTypeExpression();
  this.expectToken(SyntaxKind.LineFoldEnd);
  return new TypeDeclaration(visibility, keyword, name, typeVars, next, body);
}
public parseEnumDeclaration(): EnumDeclaration {
let pubKeyword = null;
let t0 = this.getToken();
@ -817,7 +866,9 @@ export class Parser {
case SyntaxKind.StructKeyword:
return this.parseStructDeclaration();
case SyntaxKind.EnumKeyword:
return this.parseEnumDeclaration();
return this.parseStructDeclaration();
case SyntaxKind.TypeKeyword:
return this.parseTypeDeclaration();
case SyntaxKind.IfKeyword:
return this.parseIfStatement();
default: