Improve type-checking of struct declarations/expressions

This commit is contained in:
Sam Vervaeck 2022-09-08 23:32:25 +02:00
parent e5563dd33d
commit 80bfc5f57b
2 changed files with 87 additions and 34 deletions

View file

@ -9,7 +9,7 @@ import {
SyntaxKind, SyntaxKind,
TypeExpression TypeExpression
} from "./cst"; } from "./cst";
import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, UnificationFailedDiagnostic } from "./diagnostics"; import { ArityMismatchDiagnostic, BindingNotFoudDiagnostic, describeType, Diagnostics, FieldDoesNotExistDiagnostic, FieldMissingDiagnostic, UnificationFailedDiagnostic } from "./diagnostics";
import { assert, isEmpty } from "./util"; import { assert, isEmpty } from "./util";
import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl" import { LabeledDirectedHashGraph, LabeledGraph, strongconnect } from "yagl"
@ -743,21 +743,13 @@ export class Checker {
return new TAny(); return new TAny();
} }
const recordType = this.instantiate(scheme, node); const recordType = this.instantiate(scheme, node);
const type = this.createTypeVar(); assert(recordType.kind === TypeKind.Record);
const fields = new Map();
for (const member of node.members) { for (const member of node.members) {
switch (member.kind) { switch (member.kind) {
case SyntaxKind.StructExpressionField: case SyntaxKind.StructExpressionField:
{ {
this.addConstraint( fields.set(member.name.text, this.inferExpression(member.expression));
new CEqual(
new TLabeled(
member.name.text,
this.inferExpression(member.expression)
),
type,
member,
)
);
break; break;
} }
case SyntaxKind.PunnedStructExpressionField: case SyntaxKind.PunnedStructExpressionField:
@ -770,19 +762,14 @@ export class Checker {
} else { } else {
fieldType = this.instantiate(scheme, member); fieldType = this.instantiate(scheme, member);
} }
this.addConstraint( fields.set(member.name.text, fieldType);
new CEqual(
fieldType,
type,
member
)
);
break; break;
} }
default: default:
throw new Error(`Unexpected ${member}`); throw new Error(`Unexpected ${member}`);
} }
} }
const type = new TRecord(recordType.decl, fields);
this.addConstraint( this.addConstraint(
new CEqual( new CEqual(
recordType, recordType,
@ -1342,6 +1329,7 @@ export class Checker {
case ConstraintKind.Equal: case ConstraintKind.Equal:
{ {
// constraint.dump();
if (!this.unify(constraint.left, constraint.right, solution, constraint)) { if (!this.unify(constraint.left, constraint.right, solution, constraint)) {
errorCount++; errorCount++;
if (errorCount === MAX_TYPE_ERROR_COUNT) { if (errorCount === MAX_TYPE_ERROR_COUNT) {
@ -1423,9 +1411,10 @@ export class Checker {
} }
if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) { if (left.kind === TypeKind.Labeled && right.kind === TypeKind.Labeled) {
//const remaining = new Set(right.fields.keys());
let success = false; let success = false;
// This works like an ordinary union-find algorithm where an additional
// property 'fields' is carried over from the child nodes to the
// ever-changing root node.
const root = left.find(); const root = left.find();
right.parent = root; right.parent = root;
if (root.fields === undefined) { if (root.fields === undefined) {
@ -1447,12 +1436,12 @@ export class Checker {
return success; return success;
} }
if (left.kind === TypeKind.Record && right.kind === TypeKind.Record) {
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) { if (left.decl !== right.decl) {
let success = true; this.diagnostics.add(new UnificationFailedDiagnostic(left, right, [...constraint.getNodes()]));
if (right.fields === undefined) { return false;
right.fields = new Map([ [ right.name, right.type ] ]);
} }
let success = true;
const remaining = new Set(right.fields.keys()); const remaining = new Set(right.fields.keys());
for (const [fieldName, fieldType] of left.fields) { for (const [fieldName, fieldType] of left.fields) {
if (right.fields.has(fieldName)) { if (right.fields.has(fieldName)) {
@ -1460,13 +1449,29 @@ export class Checker {
success = false; success = false;
} }
remaining.delete(fieldName); remaining.delete(fieldName);
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName));
success = false;
} }
} }
for (const fieldName of remaining) { for (const fieldName of remaining) {
this.diagnostics.add(new FieldDoesNotExistDiagnostic(left, fieldName));
}
return success;
}
if (left.kind === TypeKind.Record && right.kind === TypeKind.Labeled) {
let success = true;
if (right.fields === undefined) {
right.fields = new Map([ [ right.name, right.type ] ]);
}
for (const [fieldName, fieldType] of left.fields) {
if (left.fields.has(fieldName)) { if (left.fields.has(fieldName)) {
if (!this.unify(left.fields.get(fieldName)!, right.fields.get(fieldName)!, solution, constraint)) { if (!this.unify(fieldType, left.fields.get(fieldName)!, solution, constraint)) {
success = false; success = false;
} }
} else {
this.diagnostics.add(new FieldMissingDiagnostic(right, fieldName));
} }
} }
return success; return success;

View file

@ -1,4 +1,5 @@
import { describe } from "yargs";
import { TypeKind, type Type, type TArrow } from "./checker"; import { TypeKind, type Type, type TArrow } from "./checker";
import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst"; import { Syntax, SyntaxKind, TextFile, TextPosition, TextRange, Token } from "./cst";
import { countDigits } from "./util"; import { countDigits } from "./util";
@ -135,7 +136,7 @@ export class UnexpectedTokenDiagnostic {
public format(): string { public format(): string {
return ANSI_FG_RED + ANSI_BOLD + 'fatal: ' + ANSI_RESET return ANSI_FG_RED + ANSI_BOLD + 'fatal: ' + ANSI_RESET
+ `expected ${describeExpected(this.expected)} but got ${describeActual(this.actual)}\n\n` + `expected ${describeExpected(this.expected)} but got ${describeActual(this.actual)}\n\n`
+ printExcerpt(this.file, this.actual.getRange()) + '\n'; + printNode(this.actual) + '\n';
} }
} }
@ -152,10 +153,9 @@ export class BindingNotFoudDiagnostic {
} }
public format(): string { public format(): string {
const file = this.node.getSourceFile().getFile();
return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET
+ `binding '${this.name}' was not found.\n\n` + `binding '${this.name}' was not found.\n\n`
+ printExcerpt(file, this.node.getRange()) + '\n'; + printNode(this.node) + '\n';
} }
} }
@ -229,16 +229,14 @@ export class UnificationFailedDiagnostic {
public format(): string { public format(): string {
const node = this.nodes[0]; const node = this.nodes[0];
const file = node.getSourceFile().getFile();
let out = ANSI_FG_RED + ANSI_BOLD + `error: ` + ANSI_RESET let out = ANSI_FG_RED + ANSI_BOLD + `error: ` + ANSI_RESET
+ `unification of ` + ANSI_FG_GREEN + describeType(this.left) + ANSI_RESET + `unification of ` + ANSI_FG_GREEN + describeType(this.left) + ANSI_RESET
+ ' and ' + ANSI_FG_GREEN + describeType(this.right) + ANSI_RESET + ' failed.\n\n' + ' and ' + ANSI_FG_GREEN + describeType(this.right) + ANSI_RESET + ' failed.\n\n'
+ printExcerpt(file, node.getRange()) + '\n'; + printNode(node) + '\n';
for (let i = 1; i < this.nodes.length; i++) { for (let i = 1; i < this.nodes.length; i++) {
const node = this.nodes[i]; const node = this.nodes[i];
const file = node.getSourceFile().getFile();
out += ' ... in an instantiation of the following expression\n\n' out += ' ... in an instantiation of the following expression\n\n'
out += printExcerpt(file, node.getRange(), { indentation: i === 0 ? ' ' : ' ' }) + '\n'; out += printNode(node, { indentation: i === 0 ? ' ' : ' ' }) + '\n';
} }
return out; return out;
} }
@ -267,12 +265,52 @@ export class ArityMismatchDiagnostic {
} }
export class FieldMissingDiagnostic {
public readonly level = Level.Error;
public constructor(
public recordType: TRecord,
public fieldName: string,
) {
}
public format(): string {
return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET
+ `field '${this.fieldName}' is missing from `
+ describeType(this.recordType) + '\n\n';
}
}
export class FieldDoesNotExistDiagnostic {
public readonly level = Level.Error;
public constructor(
public recordType: TRecord,
public fieldName: string,
) {
}
public format(): string {
return ANSI_FG_RED + ANSI_BOLD + 'error: ' + ANSI_RESET
+ `field '${this.fieldName}' does not exist on type `
+ describeType(this.recordType) + '\n\n';
}
}
export type Diagnostic export type Diagnostic
= UnexpectedCharDiagnostic = UnexpectedCharDiagnostic
| BindingNotFoudDiagnostic | BindingNotFoudDiagnostic
| UnificationFailedDiagnostic | UnificationFailedDiagnostic
| UnexpectedTokenDiagnostic | UnexpectedTokenDiagnostic
| ArityMismatchDiagnostic | ArityMismatchDiagnostic
| FieldMissingDiagnostic
| FieldDoesNotExistDiagnostic
export interface Diagnostics { export interface Diagnostics {
add(diagnostic: Diagnostic): void; add(diagnostic: Diagnostic): void;
@ -309,6 +347,16 @@ export class ConsoleDiagnostics {
} }
interface PrintExcerptOptions {
indentation?: string;
extraLineCount?: number;
}
function printNode(node: Syntax, options?: PrintExcerptOptions): string {
const file = node.getSourceFile().getFile();
return printExcerpt(file, node.getRange(), options);
}
function printExcerpt(file: TextFile, span: TextRange, { indentation = ' ', extraLineCount = 2 } = {}): string { function printExcerpt(file: TextFile, span: TextRange, { indentation = ' ', extraLineCount = 2 } = {}): string {
let out = ''; let out = '';
const content = file.text; const content = file.text;