Multiple fixes related to the type-checker

- Add more tests
 - Make struct-declarations type-check
 - Split environment into type bindings and variable bindings
 - Fix kind inference adding the wrong element to the env
This commit is contained in:
Sam Vervaeck 2022-09-15 11:49:53 +02:00
parent 2ec8649456
commit 2f359107c4
5 changed files with 199 additions and 102 deletions

View file

@ -22,7 +22,7 @@ import {
UnificationFailedDiagnostic,
KindMismatchDiagnostic
} from "./diagnostics";
import { assert, isEmpty } from "./util";
import { assert, isEmpty, MultiMap } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
const MAX_TYPE_ERROR_COUNT = 5;
@ -340,10 +340,10 @@ export class TApp extends TypeBase {
super(node);
}
public static build(types: Type[]): Type {
public static build(types: Type[], node: Syntax | null = null): Type {
let result = types[0];
for (let i = 1; i < types.length; i++) {
result = new TApp(result, types[i]);
result = new TApp(result, types[i], node);
}
return result;
}
@ -716,23 +716,24 @@ type Scheme
export class TypeEnv {
private mapping = new Map<string, Scheme>();
private mapping = new MultiMap<string, [Symkind, Scheme]>();
public constructor(public parent: TypeEnv | null = null) {
}
public add(name: string, scheme: Scheme): void {
this.mapping.set(name, scheme);
public add(name: string, scheme: Scheme, kind: Symkind): void {
this.mapping.add(name, [kind, scheme]);
}
public lookup(name: string): Scheme | null {
public lookup(name: string, expectedKind: Symkind): Scheme | null {
let curr: TypeEnv | null = this;
do {
const scheme = curr.mapping.get(name);
if (scheme !== undefined) {
for (const [kind, scheme] of curr.mapping.get(name)) {
if (kind & expectedKind) {
return scheme;
}
}
curr = curr.parent;
} while(curr !== null);
return null;
@ -842,9 +843,9 @@ export class Checker {
this.contexts.pop();
}
private lookup(name: string): Scheme | null {
private lookup(name: string, kind: Symkind): Scheme | null {
const context = this.contexts[this.contexts.length-1];
return context.env.lookup(name);
return context.env.lookup(name, kind);
}
private getReturnType(): Type {
@ -871,9 +872,9 @@ export class Checker {
return scheme.type.substitute(sub);
}
private addBinding(name: string, scheme: Scheme): void {
private addBinding(name: string, scheme: Scheme, kind: Symkind): void {
const context = this.contexts[this.contexts.length-1];
context.env.add(name, scheme);
context.env.add(name, scheme, kind);
}
private inferKindFromTypeExpression(node: TypeExpression, env: KindEnv): Kind {
@ -959,6 +960,10 @@ export class Checker {
}
case SyntaxKind.StructDeclaration:
{
env.setNamed(node.name.text, this.createKindVar());
break;
}
case SyntaxKind.EnumDeclaration:
{
env.setNamed(node.name.text, this.createKindVar());
@ -1136,7 +1141,7 @@ export class Checker {
let type;
if (node.pattern.kind === SyntaxKind.WrappedOperator) {
type = this.createTypeVar();
this.addBinding(node.pattern.operator.text, new Forall([], [], type));
this.addBinding(node.pattern.operator.text, new Forall([], [], type), Symkind.Var);
} else {
type = this.inferBindings(node.pattern, [], []);
}
@ -1185,27 +1190,6 @@ export class Checker {
}
private buildVariantType(decl: EnumDeclarationElement, type: Type): Type {
const enumDecl = decl.parent as EnumDeclaration;
const kindArgs = [];
for (const _ of enumDecl.varExps) {
kindArgs.push(this.createTypeVar());
}
const variantTypes = [];
if (enumDecl.members !== null) {
for (const member of enumDecl.members) {
let variantType;
if (member === decl) {
variantType = type;
} else {
variantType = this.createTypeVar();
}
variantTypes.push(variantType);
}
}
return TApp.build([ ...kindArgs, new TVariant(enumDecl, [], []) ]);
}
public inferExpression(node: Expression): Type {
switch (node.kind) {
@ -1221,7 +1205,7 @@ export class Checker {
if (target !== null && target.kind === SyntaxKind.LetDeclaration && target.active) {
return target.type!;
}
const scheme = this.lookup(node.name.name.text);
const scheme = this.lookup(node.name.name.text, Symkind.Var);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.name.text, node.name.name));
return this.createTypeVar();
@ -1285,7 +1269,7 @@ export class Checker {
case SyntaxKind.NamedTupleExpression:
{
// TODO Only lookup constructors and skip other bindings
const scheme = this.lookup(node.name.text);
const scheme = this.lookup(node.name.text, Symkind.Var);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
@ -1310,20 +1294,19 @@ export class Checker {
case SyntaxKind.StructExpression:
{
const scope = node.getScope();
const decl = scope.lookup(node.name.text, Symkind.Constructor);
const decl = scope.lookup(node.name.text, Symkind.Type);
if (decl === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
}
assert(decl.kind === SyntaxKind.StructDeclaration || decl.kind === SyntaxKind.EnumDeclarationStructElement);
const scheme = decl.scheme;
const sub = this.createSubstitution(scheme);
const declType = this.instantiate(scheme, node, sub);
const argTypes = [];
for (const typeVar of decl.tvs) {
const newTypeVar = sub.get(typeVar)!;
assert(newTypeVar.kind === TypeKind.Var);
argTypes.push(newTypeVar);
const declType = this.instantiate(scheme, node);
const kindArgs = [];
const varExps = decl.kind === SyntaxKind.StructDeclaration
? decl.varExps : (decl.parent! as EnumDeclaration).varExps;
for (const _ of varExps) {
kindArgs.push(this.createTypeVar());
}
const fields = new Map();
for (const member of node.members) {
@ -1335,7 +1318,7 @@ export class Checker {
}
case SyntaxKind.PunnedStructExpressionField:
{
const scheme = this.lookup(member.name.text);
const scheme = this.lookup(member.name.text, Symkind.Var);
let fieldType;
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(member.name.text, member.name));
@ -1350,9 +1333,10 @@ export class Checker {
throw new Error(`Unexpected ${member}`);
}
}
let type: Type = new TRecord(decl, argTypes, fields, node);
let type: Type = TApp.build([ ...kindArgs, new TRecord(decl, [], fields, node) ]);
if (decl.kind === SyntaxKind.EnumDeclarationStructElement) {
type = this.buildVariantType(decl, type);
// TODO
// type = this.buildVariantType(decl, type);
}
this.addConstraint(
new CEqual(
@ -1366,7 +1350,7 @@ export class Checker {
case SyntaxKind.InfixExpression:
{
const scheme = this.lookup(node.operator.text);
const scheme = this.lookup(node.operator.text, Symkind.Var);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.operator.text, node.operator));
return this.createTypeVar();
@ -1398,7 +1382,7 @@ export class Checker {
case SyntaxKind.ReferenceTypeExpression:
{
const scheme = this.lookup(node.name.text);
const scheme = this.lookup(node.name.text, Symkind.Type);
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
return this.createTypeVar();
@ -1414,13 +1398,13 @@ export class Checker {
case SyntaxKind.VarTypeExpression:
{
const scheme = this.lookup(node.name.text);
const scheme = this.lookup(node.name.text, Symkind.Type);
if (scheme === null) {
if (!introduceTypeVars) {
this.diagnostics.add(new BindingNotFoudDiagnostic(node.name.text, node.name));
}
const type = this.createTypeVar();
this.addBinding(node.name.text, new Forall([], [], type));
this.addBinding(node.name.text, new Forall([], [], type), Symkind.Type);
return type;
}
assert(isEmpty(scheme.typeVars));
@ -1460,13 +1444,13 @@ export class Checker {
case SyntaxKind.BindPattern:
{
const type = this.createTypeVar();
this.addBinding(pattern.name.text, new Forall(typeVars, constraints, type));
this.addBinding(pattern.name.text, new Forall(typeVars, constraints, type), Symkind.Var);
return type;
}
case SyntaxKind.StructPattern:
{
const scheme = this.lookup(pattern.name.text);
const scheme = this.lookup(pattern.name.text, Symkind.Type);
let recordType;
if (scheme === null) {
this.diagnostics.add(new BindingNotFoudDiagnostic(pattern.name.text, pattern.name));
@ -1492,7 +1476,7 @@ export class Checker {
case SyntaxKind.PunnedStructPatternField:
{
const fieldType = this.createTypeVar();
this.addBinding(member.name.text, new Forall([], [], fieldType));
this.addBinding(member.name.text, new Forall([], [], fieldType), Symkind.Var);
this.addConstraint(
new CEqual(
new TLabeled(member.name.text, fieldType),
@ -1756,7 +1740,7 @@ export class Checker {
const kindArgs = [];
for (const varExpr of node.varExps) {
const kindArg = this.createTypeVar();
env.add(varExpr.text, new Forall([], [], kindArg));
env.add(varExpr.text, new Forall([], [], kindArg), Symkind.Type);
kindArgs.push(kindArg);
}
let elementTypes: Type[] = [];
@ -1769,7 +1753,7 @@ export class Checker {
{
const argTypes = member.elements.map(el => this.inferTypeExpression(el));
elementType = new TArrow(argTypes, TApp.build([ ...kindArgs, type ]));
parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType));
parentEnv.add(member.name.text, new Forall(typeVars, constraints, elementType), Symkind.Var);
break;
}
// TODO
@ -1780,7 +1764,7 @@ export class Checker {
}
}
this.popContext(context);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type));
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type);
break;
}
@ -1796,13 +1780,13 @@ export class Checker {
returnType: null,
};
this.pushContext(context);
for (const varExpr of node.typeVars) {
env.add(varExpr.text, new Forall([], [], this.createTypeVar()));
for (const varExpr of node.varExps) {
env.add(varExpr.text, new Forall([], [], this.createTypeVar()), Symkind.Type);
}
const type = this.inferTypeExpression(node.typeExpression);
this.popContext(context);
const scheme = new Forall(typeVars, constraints, type);
parentEnv.add(node.name.text, scheme);
parentEnv.add(node.name.text, scheme, Symkind.Type);
node.scheme = scheme;
break;
}
@ -1819,11 +1803,11 @@ export class Checker {
returnType: null,
};
this.pushContext(context);
const argTypes = [];
for (const varExpr of node.typeVars) {
const type = this.createTypeVar();
env.add(varExpr.text, new Forall([], [], type));
argTypes.push(type);
const kindArgs = [];
for (const varExpr of node.varExps) {
const kindArg = this.createTypeVar();
env.add(varExpr.text, new Forall([], [], kindArg), Symkind.Type);
kindArgs.push(kindArg);
}
const fields = new Map<string, Type>();
if (node.members !== null) {
@ -1832,11 +1816,9 @@ export class Checker {
}
}
this.popContext(context);
const type = new TRecord(node, argTypes, fields, node);
const scheme = new Forall(typeVars, constraints, type);
parentEnv.add(node.name.text, scheme);
node.tvs = argTypes;
node.scheme = scheme; //new Forall(typeVars, constraints, new TApp(type, argTypes));
const type = new TRecord(node, [], fields, node);
parentEnv.add(node.name.text, new Forall(typeVars, constraints, type), Symkind.Type);
node.scheme = new Forall(typeVars, constraints, TApp.build([ ...kindArgs, type ]));
break;
}
@ -1867,17 +1849,17 @@ export class Checker {
const b = this.createTypeVar();
const f = this.createTypeVar();
env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)));
env.add('String', new Forall([], [], this.stringType));
env.add('Int', new Forall([], [], this.intType));
env.add('True', new Forall([], [], this.boolType));
env.add('False', new Forall([], [], this.boolType));
env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)));
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)));
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)));
env.add('$', new Forall([ f, a ], [], new TArrow([ new TArrow([ a ], b), a ], b)), Symkind.Var);
env.add('String', new Forall([], [], this.stringType), Symkind.Type);
env.add('Int', new Forall([], [], this.intType), Symkind.Type);
env.add('True', new Forall([], [], this.boolType), Symkind.Var);
env.add('False', new Forall([], [], this.boolType), Symkind.Var);
env.add('+', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('-', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('*', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('/', new Forall([], [], new TArrow([ this.intType, this.intType ], this.intType)), Symkind.Var);
env.add('==', new Forall([ a ], [], new TArrow([ a, a ], this.boolType)), Symkind.Var);
env.add('not', new Forall([], [], new TArrow([ this.boolType ], this.boolType)), Symkind.Var);
const graph = new LabeledDirectedHashGraph<NodeWithBindings, boolean>();
this.addReferencesToGraph(graph, node, node);
@ -1958,7 +1940,7 @@ export class Checker {
let ty2;
if (node.pattern.kind === SyntaxKind.WrappedOperator) {
ty2 = this.createTypeVar();
this.addBinding(node.pattern.operator.text, new Forall([], [], ty2));
this.addBinding(node.pattern.operator.text, new Forall([], [], ty2), Symkind.Var);
} else {
ty2 = this.inferBindings(node.pattern, typeVars, constraints);
}
@ -1974,7 +1956,7 @@ export class Checker {
&& isFunctionDeclarationLike(element)
&& graph.hasEdge(node, element, false)) {
assert(element.pattern.kind === SyntaxKind.BindPattern);
const scheme = this.lookup(element.pattern.name.text);
const scheme = this.lookup(element.pattern.name.text, Symkind.Var);
assert(scheme !== null);
this.instantiate(scheme, null);
} else {
@ -2210,14 +2192,23 @@ export class Checker {
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let leftElement: Type = left;
while (leftElement.kind === TypeKind.App) {
leftElement = leftElement.right;
}
let rightElement: Type = right;
while (rightElement.kind === TypeKind.App) {
rightElement = rightElement.right;
}
if (leftElement.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of right.fields) {
if (left.fields.has(fieldName)) {
if (!unify(fieldType, left.fields.get(fieldName)!)) {
if (leftElement.fields.has(fieldName)) {
if (!unify(fieldType, leftElement.fields.get(fieldName)!)) {
success = false;
}
} else {

View file

@ -1,4 +1,4 @@
import { JSONObject, JSONValue } from "./util";
import { JSONObject, JSONValue, MultiMap } from "./util";
import type { InferContext, Type, TypeEnv } from "./checker"
export type TextSpan = [number, number];
@ -210,13 +210,12 @@ function isNodeWithScope(node: Syntax): node is NodeWithScope {
export const enum Symkind {
Var = 1,
Type = 2,
Constructor = 4,
Any = Var | Type | Constructor
Any = Var | Type
}
export class Scope {
private mapping = new Map<string, [Symkind, Syntax]>();
private mapping = new MultiMap<string, [Symkind, Syntax]>();
public constructor(
public node: NodeWithScope,
@ -236,7 +235,7 @@ export class Scope {
}
private add(name: string, node: Syntax, kind: Symkind): void {
this.mapping.set(name, [kind, node]);
this.mapping.add(name, [kind, node]);
}
private scan(node: Syntax): void {
@ -262,13 +261,14 @@ export class Scope {
this.add(node.name.text, node, Symkind.Type);
if (node.members !== null) {
for (const member of node.members) {
this.add(member.name.text, member, Symkind.Constructor);
this.add(member.name.text, member, Symkind.Var);
}
}
}
case SyntaxKind.StructDeclaration:
{
this.add(node.name.text, node, Symkind.Constructor | Symkind.Type);
this.add(node.name.text, node, Symkind.Type);
this.add(node.name.text, node, Symkind.Var);
break;
}
case SyntaxKind.LetDeclaration:
@ -326,12 +326,10 @@ export class Scope {
}
}
public lookup(name: string, expectedKind = Symkind.Any): Syntax | null {
public lookup(name: string, expectedKind: Symkind = Symkind.Any): Syntax | null {
let curr: Scope | null = this;
do {
const match = curr.mapping.get(name);
if (match !== undefined) {
const [kind, decl] = match;
for (const [kind, decl] of curr.mapping.get(name)) {
if (kind & expectedKind) {
return decl;
}
@ -1818,7 +1816,7 @@ export class StructDeclaration extends SyntaxBase {
public pubKeyword: PubKeyword | null,
public structKeyword: StructKeyword,
public name: IdentifierAlt,
public typeVars: Identifier[],
public varExps: Identifier[],
public members: StructDeclarationField[] | null,
) {
super();
@ -1946,7 +1944,7 @@ export class TypeDeclaration extends SyntaxBase {
public pubKeyword: PubKeyword | null,
public typeKeyword: TypeKeyword,
public name: IdentifierAlt,
public typeVars: Identifier[],
public varExps: Identifier[],
public equals: Equals,
public typeExpression: TypeExpression
) {

View file

@ -306,7 +306,7 @@ export class FieldMissingDiagnostic {
public readonly level = Level.Error;
public constructor(
public recordType: TRecord,
public recordType: Type,
public fieldName: string,
public node: Syntax | null,
) {

View file

@ -112,3 +112,43 @@ let fac n.
else.
return n * fac (n-"foo")
```
## Enum-declarations are correctly typed
```
enum Maybe a.
Just a
Nothing
let right_1 : Maybe Int = Just 1
let right_2 : Maybe String = Just "foo"
let wrong : Maybe Int = Just "foo"
```
## Kind inference works
```
enum Maybe a.
Just a
Nothing
let foo_1 : Maybe
let foo_2 : Maybe Int
let foo_3 : Maybe Int Int
```
## Can indirectly apply a polymorphic datatype to some type
```
enum Maybe a.
Just a
Nothing
enum App a b.
MkApp (a b)
enum Foo.
MkFoo (App Maybe Int)
let f : Foo = MkFoo (MkApp (Just 1))
```

View file

@ -118,3 +118,71 @@ export abstract class BufferedStream<T> {
}
export class MultiMap<K, V> {
private mapping = new Map<K, V[]>();
public get(key: K): V[] {
return this.mapping.get(key) ?? [];
}
public add(key: K, value: V): void {
let elements = this.mapping.get(key);
if (elements === undefined) {
elements = [];
this.mapping.set(key, elements);
}
elements.push(value);
}
public has(key: K, value?: V): boolean {
if (value === undefined) {
return this.mapping.has(key);
}
const elements = this.mapping.get(key);
if (elements === undefined) {
return false;
}
return elements.indexOf(value) !== -1;
}
public keys(): Iterable<K> {
return this.mapping.keys();
}
public *values(): Iterable<V> {
for (const elements of this.mapping.values()) {
yield* elements;
}
}
public *[Symbol.iterator](): Iterable<[K, V]> {
for (const [key, elements] of this.mapping) {
for (const value of elements) {
yield [key, value];
}
}
}
public delete(key: K, value?: V): number {
const elements = this.mapping.get(key);
if (elements === undefined) {
return 0;
}
if (value === undefined) {
this.mapping.delete(key);
return elements.length;
}
const i = elements.indexOf(value);
if (i !== -1) {
elements.splice(i, 1);
if (elements.length === 0) {
this.mapping.delete(key);
}
return 1;
}
return 0;
}
}