Improve type-checking of struct declarations/expressions

This commit is contained in:
Sam Vervaeck 2022-09-08 23:32:25 +02:00
parent e5563dd33d
commit 80bfc5f57b
2 changed files with 87 additions and 34 deletions

View file

@ -9,7 +9,7 @@ import {
SyntaxKind,
TypeExpression
} from "./cst";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, FieldDoesNotExistDiagnostic, FieldMissingDiagnostic, UnificationFailedDiagnostic } from "./diagnostics";
import { assert, isEmpty } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
@ -743,21 +743,13 @@ export class Checker {
return new TAny();
}
const recordType = this.instantiate(scheme, node);
const type = this.createTypeVar();
assert(recordType.kind === TypeKind.Record);
const fields = new Map();
for (const member of node.members) {
switch (member.kind) {
case SyntaxKind.StructExpressionField:
{
this.addConstraint(
new CEqual(
new TLabeled(
member.name.text,
this.inferExpression(member.expression)
),
type,
member,
)
);
fields.set(member.name.text, this.inferExpression(member.expression));
break;
}
case SyntaxKind.PunnedStructExpressionField:
@ -770,19 +762,14 @@ export class Checker {
} else {
fieldType = this.instantiate(scheme, member);
}
this.addConstraint(
new CEqual(
fieldType,
type,
member
)
);
fields.set(member.name.text, fieldType);
break;
}
default:
throw new Error(`Unexpected ${member}`);
}
}
const type = new TRecord(recordType.decl, fields);
this.addConstraint(
new CEqual(
recordType,
@ -1342,6 +1329,7 @@ export class Checker {
case ConstraintKind.Equal:
{
// constraint.dump();
if (!this.unify(constraint.left, constraint.right, solution, constraint)) {
errorCount++;
if (errorCount === MAX_TYPE_ERROR_COUNT) {
@ -1423,9 +1411,10 @@ export class Checker {
}
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
//const remaining = new Set(right.fields.keys());
let success = false;
// This works like an ordinary union-find algorithm where an additional
// property 'fields' is carried over from the child nodes to the
// ever-changing root node.
const root = left.find();
right.parent = root;
if (root.fields === undefined) {
@ -1447,12 +1436,12 @@ export class Checker {
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
if (left.decl !== right.decl) {
this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
return false;
}
let success = true;
const remaining = new Set(right.fields.keys());
for (const [fieldName, fieldType] of left.fields) {
if (right.fields.has(fieldName)) {
@ -1460,13 +1449,29 @@ export class Checker {
success = false;
}
remaining.delete(fieldName);
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName));
success = false;
}
}
for (const fieldName of remaining) {
this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName));
}
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of left.fields) {
if (left.fields.has(fieldName)) {
if (!this.unify(left.fields.get(fieldName)!, right.fields.get(fieldName)!, solution, constraint)) {
if (!this.unify(fieldType, left.fields.get(fieldName)!, solution, constraint)) {
success = false;
}
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName));
}
}
return success;

View file

@ -1,4 +1,5 @@
import { describe } from "yargs";
import { TypeKind, type Type, type TArrow } from "./checker";
import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst";
import { countDigits } from "./util";
@ -135,7 +136,7 @@ export class UnexpectedTokenDiagnostic {
public format(): string {
return ANSI_FG_RED + ANSI_BOLD + 'fatal: ' + ANSI_RESET
+ `expected ${describeExpected(this.expected)} but got ${describeActual(this.actual)}\n\n`
+ printExcerpt(this.file, this.actual.getRange()) + '\n';
+ printNode(this.actual) + '\n';
}
}
@ -152,10 +153,9 @@ export class BindingNotFoudDiagnostic {
}
public format(): string {
const file = this.node.getSourceFile().getFile();
return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET
+ `binding '${this.name}' was not found.\n\n`
+ printExcerpt(file, this.node.getRange()) + '\n';
+ printNode(this.node) + '\n';
}
}
@ -229,16 +229,14 @@ export class UnificationFailedDiagnostic {
public format(): string {
const node = this.nodes[0];
const file = node.getSourceFile().getFile();
let out = ANSI_FG_RED + ANSI_BOLD + `error: ` + ANSI_RESET
+ `unification of ` + ANSI_FG_GREEN + describeType(this.left) + ANSI_RESET
+ ' and ' + ANSI_FG_GREEN + describeType(this.right) + ANSI_RESET + ' failed.\n\n'
+ printExcerpt(file, node.getRange()) + '\n';
+ printNode(node) + '\n';
for (let i = 1; i < this.nodes.length; i++) {
const node = this.nodes[i];
const file = node.getSourceFile().getFile();
out += ' ... in an instantiation of the following expression\n\n'
out += printExcerpt(file, node.getRange(), { indentation: i === 0 ? ' ' : ' ' }) + '\n';
out += printNode(node, { indentation: i === 0 ? ' ' : ' ' }) + '\n';
}
return out;
}
@ -267,12 +265,52 @@ export class ArityMismatchDiagnostic {
}
export class FieldMissingDiagnostic {
public readonly level = Level.Error;
public constructor(
public recordType: TRecord,
public fieldName: string,
) {
}
public format(): string {
return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET
+ `field '${this.fieldName}' is missing from `
+ describeType(this.recordType) + '\n\n';
}
}
export class FieldDoesNotExistDiagnostic {
public readonly level = Level.Error;
public constructor(
public recordType: TRecord,
public fieldName: string,
) {
}
public format(): string {
return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET
+ `field '${this.fieldName}' does not exist on type `
+ describeType(this.recordType) + '\n\n';
}
}
export type Diagnostic
= UnexpectedCharDiagnostic
| BindingNotFoudDiagnostic
| UnificationFailedDiagnostic
| UnexpectedTokenDiagnostic
| ArityMismatchDiagnostic
| FieldMissingDiagnostic
| FieldDoesNotExistDiagnostic
export interface Diagnostics {
add(diagnostic: Diagnostic): void;
@ -309,6 +347,16 @@ export class ConsoleDiagnostics {
}
interface PrintExcerptOptions {
indentation?: string;
extraLineCount?: number;
}
function printNode(node: Syntax, options?: PrintExcerptOptions): string {
const file = node.getSourceFile().getFile();
return printExcerpt(file, node.getRange(), options);
}
function printExcerpt(file: TextFile, span: TextRange, { indentation = ' ', extraLineCount = 2 } = {}): string {
let out = '';
const content = file.text;