import { FastStringMap, assert, isPlainObject, some, prettyPrintTag } from "./util"; import { SyntaxKind, Syntax, isBoltTypeExpression, BoltExpression, BoltFunctionDeclaration, BoltFunctionBodyElement, kindToString, SourceFile, isBoltExpression, isBoltMacroCall, BoltTypeExpression, BoltCallExpression, BoltSyntax, BoltMemberExpression, BoltDeclaration, isBoltDeclaration, isBoltTypeDeclaration, BoltTypeDeclaration, BoltReturnStatement, BoltIdentifier, BoltRecordDeclaration, isBoltRecordDeclaration, isBoltDeclarationLike } from "./ast"; import { getSymbolPathFromNode, ScopeType, SymbolResolver, SymbolInfo, SymbolPath } from "./resolver"; import { Value, Record } from "./evaluator"; import { SourceMap } from "module"; import { timingSafeEqual } from "crypto"; import { isRightAssoc, getReturnStatementsInFunctionBody, BoltFunctionBody, getModulePathToNode } from "./common"; import { relativeTimeThreshold } from "moment"; import { E_TOO_MANY_ARGUMENTS_FOR_FUNCTION_CALL, E_TOO_FEW_ARGUMENTS_FOR_FUNCTION_CALL, E_CANDIDATE_FUNCTION_REQUIRES_THIS_PARAMETER, E_ARGUMENT_HAS_NO_CORRESPONDING_PARAMETER, E_TYPES_NOT_ASSIGNABLE, E_TYPES_MISSING_MEMBER, E_NODE_DOES_NOT_CONTAIN_MEMBER, E_RECORD_MISSING_MEMBER, E_MUST_RETURN_A_VALUE, E_MAY_NOT_RETURN_BECAUSE_TYPE_RESOLVES_TO_VOID, E_MAY_NOT_RETURN_A_VALUE, E_MUST_RETURN_BECAUSE_TYPE_DOES_NOT_RESOLVE_TO_VOID } from "./diagnostics"; import { emitNode } from "./emitter"; import { BOLT_MAX_FIELDS_TO_PRINT } from "./constants"; // TODO For function bodies, we can do something special. // Sort the return types and find the largest types, eliminating types that fall under other types. // Next, add the resulting types as type hints to `fnReturnType`. enum TypeKind { OpaqueType, AnyType, NeverType, FunctionType, RecordType, PlainRecordFieldType, VariantType, UnionType, TupleType, } export type Type = OpaqueType | AnyType | NeverType | FunctionType | RecordType | VariantType | TupleType | UnionType | PlainRecordFieldType abstract class TypeBase { public abstract kind: TypeKind; /** * Holds the node that created this type, if any. */ public node?: Syntax constructor(public sym?: SymbolInfo) { } public [prettyPrintTag](): string { return prettyPrintType(this as Type); } } export function isType(value: any) { return typeof(value) === 'object' && value !== null && value.__IS_TYPE !== null; } export class OpaqueType extends TypeBase { public kind: TypeKind.OpaqueType = TypeKind.OpaqueType; } export class AnyType extends TypeBase { public kind: TypeKind.AnyType = TypeKind.AnyType; } export class NeverType extends TypeBase { public kind: TypeKind.NeverType = TypeKind.NeverType; } export class FunctionType extends TypeBase { public kind: TypeKind.FunctionType = TypeKind.FunctionType; constructor( public paramTypes: Type[], public returnType: Type, ) { super(); } public getParameterCount(): number { return this.paramTypes.length; } public getTypeAtParameterIndex(index: number) { if (index < 0 || index >= this.paramTypes.length) { throw new Error(`Could not get the parameter type at index ${index} because the index was out of bounds.`); } return this.paramTypes[index]; } } export class VariantType extends TypeBase { public kind: TypeKind.VariantType = TypeKind.VariantType; constructor(public elementTypes: Type[]) { super(); } public getOwnElementTypes(): IterableIterator { return this.elementTypes[Symbol.iterator](); } } export class UnionType extends TypeBase { private elements: Type[] = []; public kind: TypeKind.UnionType = TypeKind.UnionType; constructor(elements: Iterable = []) { super(); this.elements = [...elements]; } public addElement(element: Type): void { this.elements.push(element); } public getElementTypes(): IterableIterator { return this.elements[Symbol.iterator](); } } export type RecordFieldType = PlainRecordFieldType class PlainRecordFieldType extends TypeBase { public kind: TypeKind.PlainRecordFieldType = TypeKind.PlainRecordFieldType; constructor(public type: Type) { super(); } } export class RecordType extends TypeBase { public kind: TypeKind.RecordType = TypeKind.RecordType; private fieldTypes = new FastStringMap(); constructor( iterable?: Iterable<[string, RecordFieldType]>, ) { super(); if (iterable !== undefined) { for (const [name, type] of iterable) { this.fieldTypes.set(name, type); } } } public getFieldNames() { return this.fieldTypes.keys(); } public addField(name: string, type: RecordFieldType): void { this.fieldTypes.set(name, type); } public getFields() { return this.fieldTypes[Symbol.iterator](); } public hasField(name: string) { return name in this.fieldTypes; } public getFieldType(name: string) { return this.fieldTypes.get(name); } public clear(): void { this.fieldTypes.clear(); } } export class TupleType extends TypeBase { kind: TypeKind.TupleType = TypeKind.TupleType; constructor(public elementTypes: Type[] = []) { super(); } } export enum ErrorType { AssignmentError, NotARecord, TypeMismatch, TooFewArguments, TooManyArguments, MayNotReturnValue, MustReturnValue, } interface NotARecordError { type: ErrorType.NotARecord; node: Syntax; candidate: Syntax; } interface AssignmentError { type: ErrorType.AssignmentError; left: Syntax; right: Syntax; } interface TypeMismatchError { type: ErrorType.TypeMismatch; left: Type; right: Type; } interface TooManyArgumentsError { type: ErrorType.TooManyArguments; caller: Syntax; callee: Syntax; index: number; } interface TooFewArgumentsError { type: ErrorType.TooFewArguments; caller: Syntax; callee: Syntax; index: number; } interface MustReturnValueError { type: ErrorType.MustReturnValue; } interface MayNotReturnValueError { type: ErrorType.MayNotReturnValue; } export type CompileError = AssignmentError | TypeMismatchError | TooManyArgumentsError | TooFewArgumentsError | NotARecordError | MustReturnValueError | MayNotReturnValueError export interface FunctionSignature { paramTypes: Type[]; } function* getAllPossibleElementTypes(type: Type): IterableIterator { switch (type.kind) { case TypeKind.UnionType: { for (const elementType of type.getElementTypes()) { yield* getAllPossibleElementTypes(elementType); } break; } default: yield type; } } export function prettyPrintType(type: Type): string { let out = '' let hasElementType = false; for (const elementType of getAllPossibleElementTypes(type)) { hasElementType = true; if (elementType.sym !== undefined) { out += elementType.sym.name; } else { switch (elementType.kind) { case TypeKind.AnyType: { out += 'any'; break; } case TypeKind.RecordType: { out += '{' let i = 0; for (const [fieldName, fieldType] of elementType.getFields()) { out += fieldName + ': ' + prettyPrintType(fieldType); i++; if (i >= BOLT_MAX_FIELDS_TO_PRINT) { break } } out += '}' break; } default: throw new Error(`Could not pretty-print type ${TypeKind[elementType.kind]}`) } } } if (!hasElementType) { out += '()' } return out; } export class TypeChecker { private opaqueTypes = new FastStringMap(); private anyType = new AnyType(); private syntaxType = new UnionType(); // FIXME constructor(private resolver: SymbolResolver) { } public getTypeOfValue(value: Value): Type { if (typeof(value) === 'string') { const sym = this.resolver.resolveGlobalSymbol('String', ScopeType.Type); assert(sym !== null); return new OpaqueType(sym!); } else if (typeof(value) === 'bigint') { const sym = this.resolver.resolveGlobalSymbol('int', ScopeType.Type); assert(sym !== null); return new OpaqueType(sym!); } else if (typeof(value) === 'number') { const sym = this.resolver.resolveGlobalSymbol('f64', ScopeType.Type); assert(sym !== null); return new OpaqueType(sym!); } else if (value instanceof Record) { const recordType = new RecordType() for (const [fieldName, fieldValue] of value.getFields()) { recordType.addField(name, new PlainRecordFieldType(this.getTypeOfValue(fieldValue))); } return recordType; } else { throw new Error(`Could not determine type of given value.`); } } private checkTypeMatches(a: Type, b: Type) { switch (b.kind) { case TypeKind.FunctionType: if (a.kind === TypeKind.AnyType) { return true; } if (a.kind === TypeKind.FunctionType) { if (b.getParameterCount() > a.getParameterCount()) { a.node?.errors.push({ message: E_TOO_MANY_ARGUMENTS_FOR_FUNCTION_CALL, severity: 'error', args: { expected: a.getParameterCount(), actual: a.getParameterCount(), }, nested: [{ message: E_CANDIDATE_FUNCTION_REQUIRES_THIS_PARAMETER, severity: 'error', node: b.getTypeAtParameterIndex(a.getParameterCount()).node! }] }) } if (b.getParameterCount() < a.getParameterCount()) { let nested = []; for (let i = b.getParameterCount(); i < a.getParameterCount(); i++) { nested.push({ message: E_ARGUMENT_HAS_NO_CORRESPONDING_PARAMETER, severity: 'error', node: (a.node as BoltCallExpression).operands[i] }); } a.node?.errors.push({ message: E_TOO_FEW_ARGUMENTS_FOR_FUNCTION_CALL, severity: 'error', args: { expected: a.getParameterCount(), actual: b.getParameterCount(), }, nested, }); } const paramCount = a.getParameterCount(); for (let i = 0; i < paramCount; i++) { const paramA = a.getTypeAtParameterIndex(i); const paramB = b.getTypeAtParameterIndex(i); if (this.isTypeAssignableTo(paramA, paramB)) { a.node?.errors.push({ message: E_TYPES_NOT_ASSIGNABLE, severity: 'error', args: { left: a, right: b, }, node: a.node, }) } } } } } public registerSourceFile(sourceFile: SourceFile): void { for (const node of sourceFile.preorder()) { if (isBoltMacroCall(node)) { continue; // FIXME only continue when we're not in an expression context } if (isBoltExpression(node)) { node.type = this.createInitialTypeForExpression(node); } } for (const callExpr of sourceFile.findAllChildrenOfKind(SyntaxKind.BoltCallExpression)) { const callTypeSig = new FunctionType(callExpr.operands.map(op => op.type!), this.anyType); for (const callableType of this.findTypesInExpression(callExpr.operator)) { this.checkTypeMatches(callableType, callTypeSig); } } } private createInitialTypeForExpression(node: Syntax): Type { if (node.type !== undefined) { return node.type; } let resultType; switch (node.kind) { case SyntaxKind.BoltMatchExpression: { const unionType = new UnionType(); for (const matchArm of node.arms) { unionType.addElement(this.createInitialTypeForExpression(matchArm.body)); } resultType = unionType; break; } case SyntaxKind.BoltRecordDeclaration: { const recordSym = this.resolver.getSymbolForNode(node, ScopeType.Type); assert(recordSym !== null); if (this.opaqueTypes.has(recordSym!.id)) { resultType = this.opaqueTypes.get(recordSym!.id); } else { const opaqueType = new OpaqueType(name, node); this.opaqueTypes.set(recordSym!.id, opaqueType); resultType = opaqueType; } break; } case SyntaxKind.BoltFunctionExpression: { const paramTypes = node.params.map(param => { if (param.typeExpr === null) { return this.anyType; } return this.createInitialTypeForTypeExpression(param.typeExpr); }); let returnType = node.returnType === null ? this.anyType : this.createInitialTypeForTypeExpression(node.returnType); resultType = new FunctionType(paramTypes, returnType); break; } case SyntaxKind.BoltQuoteExpression: { resultType = this.syntaxType; break } case SyntaxKind.BoltMemberExpression: case SyntaxKind.BoltReferenceExpression: case SyntaxKind.BoltCallExpression: case SyntaxKind.BoltBlockExpression: { resultType = this.anyType; break; } case SyntaxKind.BoltConstantExpression: { resultType = this.getTypeOfValue(node.value); break; } default: throw new Error(`Could not create a type for node ${kindToString(node.kind)}.`); } node.type = resultType; return resultType; } private createInitialTypeForTypeExpression(node: BoltTypeExpression): Type { switch (node.kind) { case SyntaxKind.BoltLiftedTypeExpression: return this.createInitialTypeForExpression(node.expression); default: throw new Error(`Could not create a type for node ${kindToString(node.kind)}.`); } } public isVoidType(type: Type): boolean { return this.isTypeAssignableTo(new TupleType, type); } private *getTypesForMember(origNode: Syntax, fieldName: string, type: Type): IterableIterator { switch (type.kind) { case TypeKind.UnionType: { const typesMissingMember = []; for (const elementType of getAllPossibleElementTypes(type)) { let foundType = false; for (const recordType of this.getTypesForMemberNoUnionType(origNode, fieldName, elementType, false)) { yield recordType; foundType = true; } if (!foundType) { origNode.errors.push({ message: E_TYPES_MISSING_MEMBER, severity: 'error', }) } } } default: return this.getTypesForMemberNoUnionType(origNode, fieldName, type, true); } } private *getTypesForMemberNoUnionType(origNode: Syntax, fieldName: string, type: Type, hardError: boolean): IterableIterator { switch (type.kind) { case TypeKind.AnyType: break; case TypeKind.FunctionType: if (hardError) { origNode.errors.push({ message: E_TYPES_MISSING_MEMBER, severity: 'error', args: { name: fieldName, }, nested: [{ message: E_NODE_DOES_NOT_CONTAIN_MEMBER, severity: 'error', node: type.node!, }] }); } break; case TypeKind.RecordType: { if (type.hasField(fieldName)) { const fieldType = type.getFieldType(fieldName); assert(fieldType.kind === TypeKind.PlainRecordFieldType); yield (fieldType as PlainRecordFieldType).type; } else { if (hardError) { origNode.errors.push({ message: E_TYPES_MISSING_MEMBER, severity: 'error', args: { name: fieldName } }) } } break; } default: throw new Error(`I do not know how to find record member types for ${TypeKind[type.kind]}`) } } private isTypeAlwaysCallable(type: Type) { return type.kind === TypeKind.FunctionType; } private *getAllNodesForType(type: Type): IterableIterator { if (type.node !== undefined) { yield type.node; } switch (type.kind) { case TypeKind.UnionType: for (const elementType of type.getElementTypes()) { yield* this.getAllNodesForType(type); } break; default: } } private *findTypesInTypeExpression(node: BoltTypeExpression): IterableIterator { switch (node.kind) { case SyntaxKind.BoltTypeOfExpression: { yield* this.findTypesInExpression(node.expression); break; } case SyntaxKind.BoltReferenceTypeExpression: { const scope = this.resolver.getScopeSurroundingNode(node, ScopeType.Variable); assert(scope !== null); const symbolPath = getSymbolPathFromNode(node.name); const resolvedSym = this.resolver.resolveSymbolPath(symbolPath, scope!); if (resolvedSym !== null) { for (const decl of resolvedSym.declarations) { assert(isBoltTypeDeclaration(decl)); this.findTypesInTypeDeclaration(decl as BoltTypeDeclaration); } } break; } default: throw new Error(`Unexpected node type ${kindToString(node.kind)}`); } } private *findTypesInExpression(node: BoltExpression): IterableIterator { switch (node.kind) { case SyntaxKind.BoltMemberExpression: { for (const element of node.path) { for (const memberType of this.getTypesForMember(element, element.text, node.expression.type!)) { yield memberType; } } break; } case SyntaxKind.BoltMatchExpression: { const unionType = new UnionType(); for (const matchArm of node.arms) { unionType.addElement(this.createInitialTypeForExpression(matchArm.body)); } yield unionType; break; } case SyntaxKind.BoltQuoteExpression: { break; } case SyntaxKind.BoltCallExpression: { const nodeSignature = new FunctionType(node.operands.map(op => op.type!), this.anyType); for (const callableType of this.findTypesInExpression(node.operator)) { yield callableType; } break; } case SyntaxKind.BoltReferenceExpression: { const scope = this.resolver.getScopeSurroundingNode(node, ScopeType.Variable); assert(scope !== null); const symbolPath = getSymbolPathFromNode(node.name); const resolvedSym = this.resolver.resolveSymbolPath(symbolPath, scope!); if (resolvedSym !== null) { for (const decl of resolvedSym.declarations) { assert(isBoltDeclaration(decl)); yield* this.findTypesInDeclaration(decl as BoltDeclaration); } } break; } default: throw new Error(`Unexpected node type ${kindToString(node.kind)}`); } } private *findTypesInTypeDeclaration(node: BoltTypeDeclaration): IterableIterator { switch (node.kind) { case SyntaxKind.BoltTypeAliasDeclaration: { yield* this.findTypesInTypeExpression(node.typeExpr); break; } default: throw new Error(`Unexpected node type ${kindToString(node.kind)}`); } } private *findTypesInDeclaration(node: BoltDeclaration) { switch (node.kind) { case SyntaxKind.BoltVariableDeclaration: if (node.typeExpr !== null) { yield* this.findTypesInTypeExpression(node.typeExpr); } if (node.value !== null) { yield* this.findTypesInExpression(node.value); } break; case SyntaxKind.BoltFunctionDeclaration: { yield this.getFunctionReturnType(node); break; } default: throw new Error(`Could not find callable expressions in declaration ${kindToString(node.kind)}`) } } private getTypeOfExpression(node: BoltExpression): Type { return new UnionType(this.findTypesInExpression(node)); } private getFunctionReturnType(node: BoltFunctionDeclaration): Type { let returnType: Type = this.anyType; if (node.returnType !== null) { returnType = new UnionType(this.findTypesInTypeExpression(node.returnType)); } for (const returnStmt of this.getAllReturnStatementsInFunctionBody(node.body)) { if (returnStmt.value === null) { if (!this.isVoidType(returnType)) { returnStmt.errors.push({ message: E_MAY_NOT_RETURN_A_VALUE, severity: 'error', nested: [{ message: E_MAY_NOT_RETURN_BECAUSE_TYPE_RESOLVES_TO_VOID, severity: 'error', node: node.returnType !== null ? node.returnType : node, }] }); } } else { const stmtReturnType = this.getTypeOfExpression(returnStmt.value); if (!this.isTypeAssignableTo(returnType, stmtReturnType)) { if (this.isVoidType(stmtReturnType)) { returnStmt.value.errors.push({ message: E_MUST_RETURN_A_VALUE, severity: 'error', nested: [{ message: E_MUST_RETURN_BECAUSE_TYPE_DOES_NOT_RESOLVE_TO_VOID, severity: 'error', node: node.returnType !== null ? node.returnType : node, }] }) } else { returnStmt.value.errors.push({ message: E_TYPES_NOT_ASSIGNABLE, severity: 'error', args: { left: returnType, right: stmtReturnType, } }) } } } } return returnType; } private *getAllReturnStatementsInFunctionBody(body: BoltFunctionBody): IterableIterator { for (const element of body) { switch (element.kind) { case SyntaxKind.BoltReturnStatement: { yield element; break; } case SyntaxKind.BoltConditionalStatement: { for (const caseNode of element.cases) { yield* this.getAllReturnStatementsInFunctionBody(caseNode.body); } break; } case SyntaxKind.BoltExpressionStatement: break; default: throw new Error(`I did not know how to find return statements in ${kindToString(node.kind)}`); } } } private isTypeAssignableTo(left: Type, right: Type): boolean { if (left.kind === TypeKind.NeverType || right.kind === TypeKind.NeverType) { return false; } if (left.kind === TypeKind.AnyType || right.kind === TypeKind.AnyType) { return true; } if (left.kind === TypeKind.OpaqueType && right.kind === TypeKind.OpaqueType) { return left === right; } if (left.kind === TypeKind.RecordType && right.kind === TypeKind.RecordType) { for (const fieldName of left.getFieldNames()) { if (!right.hasField(fieldName)) { return false; } } for (const fieldName of right.getFieldNames()) { if (!left.hasField(fieldName)) { return false; } if (!this.isTypeAssignableTo(left.getFieldType(fieldName), right.getFieldType(fieldName))) { return false; } } return true; } if (left.kind === TypeKind.FunctionType && right.kind === TypeKind.FunctionType) { if (left.getParameterCount() !== right.getParameterCount()) { return false; } for (let i = 0; i < left.getParameterCount(); i++) { if (!this.isTypeAssignableTo(left.getTypeAtParameterIndex(i), right.getTypeAtParameterIndex(i))) { return false; } } if (!this.isTypeAssignableTo(left.returnType, right.returnType)) { return false; } return true; } return false; } }