Multiple enhancements

- Make record expressions anonymous
 - Introduce `TNominal`
 - Add experimental support for type declarations (fixes #32)
 - Fix inference of StructDeclaration
This commit is contained in:
Sam Vervaeck 2022-09-15 20:33:34 +02:00
parent 3152db9d32
commit 2d10ceedc9
4 changed files with 130 additions and 182 deletions

View file

@ -1,4 +1,5 @@
import {
Declaration,
EnumDeclaration,
EnumDeclarationStructElement,
Expression,
@ -25,6 +26,8 @@ import {
import { assert, isEmpty, MultiMap } from "./util";
import { Analyser } from "./analysis";
// TODO check that the order by which kindArgs are inserted is correct
const MAX_TYPE_ERROR_COUNT = 5;
export enum TypeKind {
@ -36,7 +39,7 @@ export enum TypeKind {
Labeled,
Record,
App,
Variant,
Nominal,
}
abstract class TypeBase {
@ -277,8 +280,6 @@ export class TRecord extends TypeBase {
public readonly kind = TypeKind.Record;
public constructor(
public decl: StructDeclaration | EnumDeclarationStructElement,
public kindArgs: TVar[],
public fields: Map<string, Type>,
public node: Syntax | null = null,
) {
@ -293,8 +294,6 @@ export class TRecord extends TypeBase {
public shallowClone(): TRecord {
return new TRecord(
this.decl,
this.kindArgs,
this.fields,
this.node
);
@ -302,15 +301,6 @@ export class TRecord extends TypeBase {
public substitute(sub: TVSub): Type {
let changed = false;
const newTypeVars = [];
for (const typeVar of this.kindArgs) {
const newTypeVar = typeVar.substitute(sub);
assert(newTypeVar.kind === TypeKind.Var);
if (newTypeVar !== typeVar) {
changed = true;
}
newTypeVars.push(newTypeVar);
}
const newFields = new Map();
for (const [key, type] of this.fields) {
const newType = type.substitute(sub);
@ -319,7 +309,7 @@ export class TRecord extends TypeBase {
}
newFields.set(key, newType);
}
return changed ? new TRecord(this.decl, newTypeVars, newFields, this.node) : this;
return changed ? new TRecord(newFields, this.node) : this;
}
}
@ -336,12 +326,11 @@ export class TApp extends TypeBase {
super(node);
}
public static build(types: Type[], node: Syntax | null = null): Type {
let result = types[0];
for (let i = 1; i < types.length; i++) {
result = new TApp(result, types[i], node);
public static build(resultType: Type, types: Type[], node: Syntax | null = null): Type {
for (let i = 0; i < types.length; i++) {
resultType = new TApp(types[i], resultType, node);
}
return result;
return resultType;
}
public *getTypeVars(): Iterable<TVar> {
@ -372,54 +361,30 @@ export class TApp extends TypeBase {
}
export class TVariant extends TypeBase {
export class TNominal extends TypeBase {
public readonly kind = TypeKind.Variant;
public readonly kind = TypeKind.Nominal;
public constructor(
public decl: EnumDeclaration,
public kindArgs: Type[],
public elementTypes: Type[],
public decl: Declaration,
public node: Syntax | null = null,
) {
super(node);
}
public *getTypeVars(): Iterable<TVar> {
for (const elementType of this.elementTypes) {
yield* elementType.getTypeVars();
}
}
public shallowClone(): Type {
return new TVariant(
return new TNominal(
this.decl,
this.kindArgs,
this.elementTypes,
this.node,
);
}
public substitute(sub: TVSub): Type {
let changed = false;
const newTypeVars = [];
for (const kindArg of this.kindArgs) {
const newTypeVar = kindArg.substitute(sub);
assert(newTypeVar.kind === TypeKind.Var);
if (newTypeVar !== kindArg) {
changed = true;
}
newTypeVars.push(newTypeVar);
}
const newElementTypes = [];
for (const elementType of this.elementTypes) {
const newElementType = elementType.substitute(sub);
if (newElementType !== elementType) {
changed = true;
}
newElementTypes.push(newElementType);
}
return changed ? new TVariant(this.decl, newTypeVars, newElementTypes, this.node) : this;
return this;
}
}
@ -432,16 +397,7 @@ export type Type
| TLabeled
| TRecord
| TApp
| TVariant
type KindedType
= TRecord
| TVariant
function isKindedType(type: Type): type is KindedType {
return type.kind === TypeKind.Variant
|| type.kind === TypeKind.Record;
}
| TNominal
export const enum KindType {
Star,
@ -973,7 +929,7 @@ export class Checker {
}
}
private inferKind(node: Syntax, env: KindEnv): void {
switch (node.kind) {
case SyntaxKind.SourceFile:
@ -985,7 +941,21 @@ export class Checker {
}
case SyntaxKind.StructDeclaration:
{
// TODO
const declKind = env.lookup(node.name.text)!;
const innerEnv = new KindEnv(env);
let kind: Kind = new KStar();
for (let i = node.varExps.length-1; i >= 0; i--) {
const varExpr = node.varExps[i];
const paramKind = this.createKindVar();
innerEnv.setNamed(varExpr.text, paramKind);
kind = new KArrow(paramKind, kind);
}
this.unifyKind(declKind, kind, node);
if (node.fields !== null) {
for (const field of node.fields) {
this.unifyKind(this.inferKindFromTypeExpression(field.typeExpr, innerEnv), new KStar(), field.typeExpr);
}
}
break;
}
case SyntaxKind.EnumDeclaration:
@ -1194,15 +1164,15 @@ export class Checker {
case SyntaxKind.ReferenceExpression:
{
assert(node.name.modulePath.length === 0);
assert(node.modulePath.length === 0);
const scope = node.getScope();
const target = scope.lookup(node.name.name.text);
const target = scope.lookup(node.name.text);
if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) {
return target.type!;
}
const scheme = this.lookup(node.name.name.text, Symkind.Var);
const scheme = this.lookup(node.name.text, Symkind.Var);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name));
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
}
const type = this.instantiate(scheme, node);
@ -1288,21 +1258,6 @@ export class Checker {
case SyntaxKind.StructExpression:
{
const scope = node.getScope();
const decl = scope.lookup(node.name.text, Symkind.Type);
if (decl === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
}
assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement);
const scheme = decl.scheme!;
const declType = this.instantiate(scheme, node);
const kindArgs = [];
const varExps = decl.kind === SyntaxKind.StructDeclaration
? decl.varExps : (decl.parent! as EnumDeclaration).varExps;
for (const _ of varExps) {
kindArgs.push(this.createTypeVar());
}
const fields = new Map();
for (const member of node.members) {
switch (member.kind) {
@ -1328,19 +1283,7 @@ export class Checker {
throw new Error(`Unexpected ${member}`);
}
}
let type: Type = TApp.build([ ...kindArgs, new TRecord(decl, [], fields, node) ]);
if (decl.kind === SyntaxKind.EnumDeclarationStructElement) {
// TODO
// type = this.buildVariantType(decl, type);
}
this.addConstraint(
new CEqual(
declType,
type,
node,
)
);
return type;
return new TRecord(fields, node);
}
case SyntaxKind.InfixExpression:
@ -1409,10 +1352,10 @@ export class Checker {
case SyntaxKind.AppTypeExpression:
{
return TApp.build([
...node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars)),
return TApp.build(
this.inferTypeExpression(node.operator, introduceTypeVars),
]);
node.args.map(arg => this.inferTypeExpression(arg, introduceTypeVars)),
);
}
case SyntaxKind.ArrowTypeExpression:
@ -1550,7 +1493,7 @@ export class Checker {
kindArgs.push(kindArg);
}
let elementTypes: Type[] = [];
const type = new TVariant(node, [], [], node);
const type = new TNominal(node, node);
if (node.members !== null) {
for (const member of node.members) {
let elementType;
@ -1558,14 +1501,22 @@ export class Checker {
case SyntaxKind.EnumDeclarationTupleElement:
{
const argTypes = member.elements.map(el => this.inferTypeExpression(el));
elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ]));
parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType), Symkind.Var);
elementType = new TArrow(argTypes, TApp.build(type, kindArgs));
break;
}
case SyntaxKind.EnumDeclarationStructElement:
{
const fields = new Map();
for (const field of member.fields) {
fields.set(field.name.text, this.inferTypeExpression(field.typeExpr));
}
elementType = new TArrow([ new TRecord(fields, member) ], TApp.build(type, kindArgs));
break;
}
// TODO
default:
throw new Error(`Unexpected ${member}`);
}
parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType), Symkind.Var);
elementTypes.push(elementType);
}
}
@ -1586,14 +1537,17 @@ export class Checker {
returnType: null,
};
this.pushContext(context);
const kindArgs = [];
for (const varExpr of node.varExps) {
env.add(varExpr.text, new Forall([], [], this.createTypeVar()), Symkind.Type);
const typeVar = this.createTypeVar();
kindArgs.push(typeVar);
env.add(varExpr.text, new Forall([], [], typeVar), Symkind.Type);
}
const type = this.inferTypeExpression(node.typeExpression);
console.log(describeType(type));
this.popContext(context);
const scheme = new Forall(typeVars, constraints, type);
const scheme = new Forall(typeVars, constraints, TApp.build(type, kindArgs));
parentEnv.add(node.name.text, scheme, Symkind.Type);
node.scheme = scheme;
break;
}
@ -1616,20 +1570,21 @@ export class Checker {
kindArgs.push(kindArg);
}
const fields = new Map<string, Type>();
if (node.members !== null) {
for (const member of node.members) {
if (node.fields !== null) {
for (const member of node.fields) {
fields.set(member.name.text, this.inferTypeExpression(member.typeExpr));
}
}
this.popContext(context);
const type = new TRecord(node, [], fields, node);
const type = new TNominal(node);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type);
node.scheme = new Forall(typeVars, constraints, TApp.build([ ...kindArgs, type ]));
parentEnv.add(node.name.text, new Forall(typeVars, constraints, new TArrow([ new TRecord(fields, node) ], TApp.build(type, kindArgs))), Symkind.Var);
//node.scheme = new Forall(typeVars, constraints, );
break;
}
default:
throw new Error(`Unexpected ${node}`);
throw new Error(`Unexpected ${node.constructor.name}`);
}
@ -1956,7 +1911,7 @@ export class Checker {
return success;
}
if (left.kind === TypeKind.Variant && right.kind === TypeKind.Variant) {
if (left.kind === TypeKind.Nominal && right.kind === TypeKind.Nominal) {
if (left.decl !== right.decl) {
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
return false;

View file

@ -1434,7 +1434,6 @@ export class StructExpression extends SyntaxBase {
public readonly kind = SyntaxKind.StructExpression;
public constructor(
public name: IdentifierAlt,
public lbrace: LBrace,
public members: StructExpressionElement[],
public rbrace: RBrace,
@ -1443,7 +1442,7 @@ export class StructExpression extends SyntaxBase {
}
public getFirstToken(): Token {
return this.name;
return this.lbrace;
}
public getLastToken(): Token {
@ -1481,17 +1480,21 @@ export class ReferenceExpression extends SyntaxBase {
public readonly kind = SyntaxKind.ReferenceExpression;
public constructor(
public name: QualifiedName,
public modulePath: Array<[IdentifierAlt, Dot]>,
public name: Identifier | IdentifierAlt,
) {
super();
}
public getFirstToken(): Token {
return this.name.getFirstToken();
if (this.modulePath.length > 0) {
return this.modulePath[0][0];
}
return this.name;
}
public getLastToken(): Token {
return this.name.getLastToken();
return this.name;
}
}
@ -1718,7 +1721,7 @@ export class EnumDeclarationStructElement extends SyntaxBase {
public constructor(
public name: IdentifierAlt,
public blockStart: BlockStart,
public members: StructDeclarationField[],
public fields: StructDeclarationField[],
) {
super();
}
@ -1728,8 +1731,8 @@ export class EnumDeclarationStructElement extends SyntaxBase {
}
public getLastToken(): Token {
if (this.members.length > 0) {
return this.members[this.members.length-1].getLastToken();
if (this.fields.length > 0) {
return this.fields[this.fields.length-1].getLastToken();
}
return this.blockStart;
}
@ -1830,7 +1833,7 @@ export class StructDeclaration extends SyntaxBase {
public structKeyword: StructKeyword,
public name: IdentifierAlt,
public varExps: Identifier[],
public members: StructDeclarationField[] | null,
public fields: StructDeclarationField[] | null,
) {
super();
}
@ -1843,8 +1846,8 @@ export class StructDeclaration extends SyntaxBase {
}
public getLastToken(): Token {
if (this.members && this.members.length > 0) {
return this.members[this.members.length-1].getLastToken();
if (this.fields && this.fields.length > 0) {
return this.fields[this.fields.length-1].getLastToken();
}
return this.name;
}

View file

@ -1,5 +1,4 @@
import { describe } from "yargs";
import { TypeKind, type Type, type TArrow, TRecord, Kind, KindType } from "./checker";
import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst";
import { countDigits, IndentWriter } from "./util";
@ -200,11 +199,21 @@ export function describeType(type: Type): string {
}
return out;
}
case TypeKind.Variant:
case TypeKind.Record:
case TypeKind.Nominal:
{
return type.decl.name.text;
}
case TypeKind.Record:
{
let out = '{ ';
let first = true;
for (const [fieldName, fieldType] of type.fields) {
if (first) first = false;
else out += ', ';
out += fieldName + ': ' + describeType(fieldType);
}
return out + ' }';
}
case TypeKind.Labeled:
{
// FIXME may need to include fields that were added during unification

View file

@ -241,7 +241,7 @@ export class Parser {
return new ConstantExpression(token);
}
public parseQualifiedName(): QualifiedName {
public parseReferenceExpression(): ReferenceExpression {
const modulePath: Array<[IdentifierAlt, Dot]> = [];
for (;;) {
const t0 = this.peekToken(1);
@ -251,12 +251,11 @@ export class Parser {
}
modulePath.push([t0, t1]);
}
const name = this.expectToken(SyntaxKind.Identifier);
return new QualifiedName(modulePath, name);
}
public parseReferenceExpression(): ReferenceExpression {
return new ReferenceExpression(this.parseQualifiedName());
const name = this.getToken();
if (name.kind !== SyntaxKind.Identifier && name.kind !== SyntaxKind.IdentifierAlt) {
this.raiseParseError(name, [ SyntaxKind.Identifier, SyntaxKind.IdentifierAlt ]);
}
return new ReferenceExpression(modulePath, name);
}
private parseExpressionWithParens(): Expression {
@ -279,65 +278,47 @@ export class Parser {
case SyntaxKind.LParen:
return this.parseExpressionWithParens();
case SyntaxKind.Identifier:
return this.parseReferenceExpression();
case SyntaxKind.IdentifierAlt:
return this.parseReferenceExpression();
case SyntaxKind.LBrace:
{
this.getToken();
const t1 = this.peekToken();
if (t1.kind === SyntaxKind.LBrace) {
this.getToken();
const fields = [];
let rbrace;
for (;;) {
const t2 = this.peekToken();
if (t2.kind === SyntaxKind.RBrace) {
this.getToken();
rbrace = t2;
break;
}
let field;
const t3 = this.getToken();
if (t3.kind === SyntaxKind.Identifier) {
const t4 = this.peekToken();
if (t4.kind === SyntaxKind.Equals) {
this.getToken();
const expression = this.parseExpression();
field = new StructExpressionField(t3, t4, expression);
} else {
field = new PunnedStructExpressionField(t3);
}
} else {
// TODO add spread fields
this.raiseParseError(t3, [ SyntaxKind.Identifier ]);
}
fields.push(field);
const t5 = this.peekToken();
if (t5.kind === SyntaxKind.Comma) {
this.getToken();
continue;
} else if (t5.kind === SyntaxKind.RBrace) {
this.getToken();
rbrace = t5;
break;
}
}
return new StructExpression(t0, t1, fields, rbrace);
}
const elements = [];
const fields = [];
let rbrace;
for (;;) {
const t2 = this.peekToken();
if (t2.kind === SyntaxKind.LineFoldEnd
|| t2.kind === SyntaxKind.Comma
|| t2.kind === SyntaxKind.RParen
|| t2.kind === SyntaxKind.RBrace
|| t2.kind === SyntaxKind.RBracket
|| isBinaryOperatorLike(t2)
|| isPrefixOperatorLike(t2)) {
if (t2.kind === SyntaxKind.RBrace) {
this.getToken();
rbrace = t2;
break;
}
let field;
const t3 = this.getToken();
if (t3.kind === SyntaxKind.Identifier) {
const t4 = this.peekToken();
if (t4.kind === SyntaxKind.Equals) {
this.getToken();
const expression = this.parseExpression();
field = new StructExpressionField(t3, t4, expression);
} else {
field = new PunnedStructExpressionField(t3);
}
} else {
// TODO add spread fields
this.raiseParseError(t3, [ SyntaxKind.Identifier ]);
}
fields.push(field);
const t5 = this.peekToken();
if (t5.kind === SyntaxKind.Comma) {
this.getToken();
continue;
} else if (t5.kind === SyntaxKind.RBrace) {
this.getToken();
rbrace = t5;
break;
}
elements.push(this.parseExpression());
}
return new NamedTupleExpression(t0, elements);
return new StructExpression(t0, fields, rbrace);
}
case SyntaxKind.Integer:
case SyntaxKind.StringLiteral: