Major update to code base

- Add partial support for extensible records
 - Rewrite unifier in Checker.cc
 - Make use of union/find instead of a HashMap for type variables
 - Enhance diagnostic messages
 - Add a variant type
 - Add application types (TApp)
 - Some smaller bugfixes
This commit is contained in:
Sam Vervaeck 2023-05-29 20:37:23 +02:00
parent dfaa91c9b6
commit 6bd8ecff39
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
11 changed files with 1748 additions and 612 deletions

View file

@ -1497,6 +1497,9 @@ namespace bolt {
Fields(Fields), Fields(Fields),
RBrace(RBrace) {} RBrace(RBrace) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
}; };
class Statement : public Node { class Statement : public Node {
@ -1800,6 +1803,7 @@ namespace bolt {
class PubKeyword* PubKeyword; class PubKeyword* PubKeyword;
class StructKeyword* StructKeyword; class StructKeyword* StructKeyword;
IdentifierAlt* Name; IdentifierAlt* Name;
std::vector<VarTypeExpression*> Vars;
class BlockStart* BlockStart; class BlockStart* BlockStart;
std::vector<RecordDeclarationField*> Fields; std::vector<RecordDeclarationField*> Fields;
@ -1807,12 +1811,14 @@ namespace bolt {
class PubKeyword* PubKeyword, class PubKeyword* PubKeyword,
class StructKeyword* StructKeyword, class StructKeyword* StructKeyword,
IdentifierAlt* Name, IdentifierAlt* Name,
std::vector<VarTypeExpression*> Vars,
class BlockStart* BlockStart, class BlockStart* BlockStart,
std::vector<RecordDeclarationField*> Fields std::vector<RecordDeclarationField*> Fields
): Node(NodeKind::RecordDeclaration), ): Node(NodeKind::RecordDeclaration),
PubKeyword(PubKeyword), PubKeyword(PubKeyword),
StructKeyword(StructKeyword), StructKeyword(StructKeyword),
Name(Name), Name(Name),
Vars(Vars),
BlockStart(BlockStart), BlockStart(BlockStart),
Fields(Fields) {} Fields(Fields) {}

View file

@ -166,6 +166,9 @@ namespace bolt {
class Checker { class Checker {
friend class Unifier;
friend class UnificationFrame;
const LanguageConfig& Config; const LanguageConfig& Config;
DiagnosticEngine& DE; DiagnosticEngine& DE;
@ -178,14 +181,10 @@ namespace bolt {
Graph<Node*> RefGraph; Graph<Node*> RefGraph;
std::unordered_map<Node*, InferContext*> CallGraph;
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap; std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
std::vector<InferContext*> Contexts; std::vector<InferContext*> Contexts;
TVSub Solution;
/** /**
* The queue that is used during solving to store any unsolved constraints. * The queue that is used during solving to store any unsolved constraints.
*/ */
@ -208,15 +207,14 @@ namespace bolt {
Type* inferTypeExpression(TypeExpression* TE); Type* inferTypeExpression(TypeExpression* TE);
Type* inferLiteral(Literal* Lit); Type* inferLiteral(Literal* Lit);
void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs); Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void inferBindings(Pattern* Pattern, Type* T);
void infer(Node* node); void infer(Node* node);
void inferLetDeclaration(LetDeclaration* N); void inferLetDeclaration(LetDeclaration* N);
Constraint* convertToConstraint(ConstraintExpression* C); Constraint* convertToConstraint(ConstraintExpression* C);
TCon* createPrimConType(); TCon* createConType(ByteString Name);
TVar* createTypeVar(); TVar* createTypeVar();
TVarRigid* createRigidVar(ByteString Name); TVarRigid* createRigidVar(ByteString Name);
InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet); InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet);
@ -239,8 +237,6 @@ namespace bolt {
*/ */
Type* lookupMono(ByteString Name); Type* lookupMono(ByteString Name);
InferContext* lookupCall(Node* Source, SymbolPath Path);
/** /**
* Get the return type for the current context. If none could be found, the program will abort. * Get the return type for the current context. If none could be found, the program will abort.
*/ */
@ -252,10 +248,6 @@ namespace bolt {
void propagateClasses(TypeclassContext& Classes, Type* Ty); void propagateClasses(TypeclassContext& Classes, Type* Ty);
void propagateClassTycon(TypeclassId& Class, TCon* Ty); void propagateClassTycon(TypeclassId& Class, TCon* Ty);
Type* simplify(Type* Ty);
Type* find(Type* Ty);
/** /**
* Assign a type to a unification variable. * Assign a type to a unification variable.
* *
@ -268,18 +260,21 @@ namespace bolt {
*/ */
void join(TVar* A, Type* B); void join(TVar* A, Type* B);
// Unification parameters
Type* OrigLeft; Type* OrigLeft;
Type* OrigRight; Type* OrigRight;
TypePath LeftPath; TypePath LeftPath;
TypePath RightPath; TypePath RightPath;
ByteString CurrentFieldName;
Node* Source; Node* Source;
bool unify(Type* A, Type* B); bool unify(Type* A, Type* B);
void unifyError(); void unifyError();
void solveCEqual(CEqual* C); void solveCEqual(CEqual* C);
void solve(Constraint* Constraint, TVSub& Solution); void solve(Constraint* Constraint);
void populate(SourceFile* SF); void populate(SourceFile* SF);
@ -293,17 +288,22 @@ namespace bolt {
Checker(const LanguageConfig& Config, DiagnosticEngine& DE); Checker(const LanguageConfig& Config, DiagnosticEngine& DE);
/**
* \internal
*/
Type* simplifyType(Type* Ty);
void check(SourceFile* SF); void check(SourceFile* SF);
inline Type* getBoolType() { inline Type* getBoolType() const {
return BoolType; return BoolType;
} }
inline Type* getStringType() { inline Type* getStringType() const {
return StringType; return StringType;
} }
inline Type* getIntType() { inline Type* getIntType() const {
return IntType; return IntType;
} }

View file

@ -6,6 +6,8 @@
#include <iostream> #include <iostream>
#include "bolt/ByteString.hpp" #include "bolt/ByteString.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
namespace bolt { namespace bolt {
@ -60,6 +62,98 @@ namespace bolt {
Magenta, Magenta,
}; };
/**
 * Bit masks describing the text attributes that a Style can carry.
 *
 * Every enumerator other than StyleFlags_None occupies its own bit, so the
 * values can be combined with bitwise OR and tested with bitwise AND.
 */
enum StyleFlags : unsigned {
  StyleFlags_None      = 0x0, // no attribute set
  StyleFlags_Bold      = 0x1, // bit 0
  StyleFlags_Underline = 0x2, // bit 1
  StyleFlags_Italic    = 0x4, // bit 2
};
/**
 * Value type bundling a foreground color, a background color and a set of
 * text attributes (bold, underline, italic).
 *
 * A color equal to Color::None is treated as "no color set"; a
 * default-constructed Style has no colors and no attributes.
 */
class Style {

  unsigned Flags = StyleFlags_None;

  Color FgColor = Color::None;
  Color BgColor = Color::None;

  /// Returns true when at least one bit of Mask is set in Flags.
  bool hasFlag(unsigned Mask) const noexcept {
    return (Flags & Mask) != 0;
  }

  /// Sets the bits of Mask when Enable is true and clears them otherwise.
  void assignFlag(unsigned Mask, bool Enable) noexcept {
    Flags = Enable ? (Flags | Mask) : (Flags & ~Mask);
  }

public:

  // --- Colors ---------------------------------------------------------

  Color getForegroundColor() const noexcept { return FgColor; }
  Color getBackgroundColor() const noexcept { return BgColor; }

  void setForegroundColor(Color NewColor) noexcept { FgColor = NewColor; }
  void setBackgroundColor(Color NewColor) noexcept { BgColor = NewColor; }

  /// True when the foreground color is anything other than Color::None.
  bool hasForegroundColor() const noexcept { return !(FgColor == Color::None); }

  /// True when the background color is anything other than Color::None.
  bool hasBackgroundColor() const noexcept { return !(BgColor == Color::None); }

  void clearForegroundColor() noexcept { FgColor = Color::None; }
  void clearBackgroundColor() noexcept { BgColor = Color::None; }

  // --- Text attributes ------------------------------------------------

  bool isUnderline() const noexcept { return hasFlag(StyleFlags_Underline); }
  bool isItalic() const noexcept { return hasFlag(StyleFlags_Italic); }
  bool isBold() const noexcept { return hasFlag(StyleFlags_Bold); }

  void setUnderline(bool Enable) noexcept { assignFlag(StyleFlags_Underline, Enable); }
  void setItalic(bool Enable) noexcept { assignFlag(StyleFlags_Italic, Enable); }
  void setBold(bool Enable) noexcept { assignFlag(StyleFlags_Bold, Enable); }

  /// Drops both colors and every text attribute, returning the object to
  /// its default-constructed state.
  void reset() noexcept {
    Flags = StyleFlags_None;
    BgColor = Color::None;
    FgColor = Color::None;
  }

};
/** /**
* Prints any diagnostic message that was added to it to the console. * Prints any diagnostic message that was added to it to the console.
*/ */
@ -67,8 +161,12 @@ namespace bolt {
std::ostream& Out; std::ostream& Out;
Style ActiveStyle;
void setForegroundColor(Color C); void setForegroundColor(Color C);
void setBackgroundColor(Color C); void setBackgroundColor(Color C);
void applyStyles();
void setBold(bool Enable); void setBold(bool Enable);
void setItalic(bool Enable); void setItalic(bool Enable);
void setUnderline(bool Enable); void setUnderline(bool Enable);
@ -99,6 +197,7 @@ namespace bolt {
void writePrefix(const Diagnostic& D); void writePrefix(const Diagnostic& D);
void writeBinding(const ByteString& Name); void writeBinding(const ByteString& Name);
void writeType(std::size_t I); void writeType(std::size_t I);
void writeType(const Type* Ty, const TypePath& Underline);
void writeType(const Type* Ty); void writeType(const Type* Ty);
void writeLoc(const TextFile& File, const TextLoc& Loc); void writeLoc(const TextFile& File, const TextLoc& Loc);
void writeTypeclassName(const ByteString& Name); void writeTypeclassName(const ByteString& Name);

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <cwchar>
#include <vector> #include <vector>
#include <stdexcept> #include <stdexcept>
#include <memory> #include <memory>
@ -23,6 +24,7 @@ namespace bolt {
ClassNotFound, ClassNotFound,
TupleIndexOutOfRange, TupleIndexOutOfRange,
InvalidTypeToTypeclass, InvalidTypeToTypeclass,
FieldNotFound,
}; };
class Diagnostic : std::runtime_error { class Diagnostic : std::runtime_error {
@ -88,14 +90,14 @@ namespace bolt {
class UnificationErrorDiagnostic : public Diagnostic { class UnificationErrorDiagnostic : public Diagnostic {
public: public:
Type* Left; Type* OrigLeft;
Type* Right; Type* OrigRight;
TypePath LeftPath; TypePath LeftPath;
TypePath RightPath; TypePath RightPath;
Node* Source; Node* Source;
inline UnificationErrorDiagnostic(Type* Left, Type* Right, TypePath LeftPath, TypePath RightPath, Node* Source): inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source):
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {}
inline Node* getNode() const override { inline Node* getNode() const override {
return Source; return Source;
@ -171,4 +173,17 @@ namespace bolt {
}; };
/**
 * Diagnostic of kind DiagnosticKind::FieldNotFound.
 *
 * It only stores the data handed to its constructor: the field name, a type,
 * a path into that type and the node the problem is attributed to.
 */
class FieldNotFoundDiagnostic : public Diagnostic {
public:

  ByteString Name; // name of the field
  Type* Ty;        // type associated with the diagnostic
  TypePath Path;   // path inside Ty
  Node* Source;    // node the diagnostic is attributed to

  FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source)
    : Diagnostic(DiagnosticKind::FieldNotFound),
      Name(Name),
      Ty(Ty),
      Path(Path),
      Source(Source) {}

};
} }

View file

@ -82,6 +82,7 @@ namespace bolt {
MatchExpression* parseMatchExpression(); MatchExpression* parseMatchExpression();
Expression* parseMemberExpression(); Expression* parseMemberExpression();
RecordExpression* parseRecordExpression();
Expression* parsePrimitiveExpression(); Expression* parsePrimitiveExpression();
ConstraintExpression* parseConstraintExpression(); ConstraintExpression* parseConstraintExpression();

View file

@ -1,6 +1,8 @@
#pragma once #pragma once
#include <functional>
#include <type_traits>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -28,10 +30,15 @@ namespace bolt {
}; };
enum class TypeIndexKind { enum class TypeIndexKind {
AppOpType,
AppArgType,
ArrowParamType, ArrowParamType,
ArrowReturnType, ArrowReturnType,
ConArg,
TupleElement, TupleElement,
FieldType,
FieldRestType,
TupleIndexType,
PresentType,
End, End,
}; };
@ -59,22 +66,42 @@ namespace bolt {
void advance(const Type* Ty); void advance(const Type* Ty);
static TypeIndex forArrowReturnType() { static TypeIndex forFieldType() {
return { TypeIndexKind::ArrowReturnType }; return { TypeIndexKind::FieldType };
}
static TypeIndex forFieldRest() {
return { TypeIndexKind::FieldRestType };
} }
static TypeIndex forArrowParamType(std::size_t I) { static TypeIndex forArrowParamType(std::size_t I) {
return { TypeIndexKind::ArrowParamType, I }; return { TypeIndexKind::ArrowParamType, I };
} }
static TypeIndex forConArg(std::size_t I) { static TypeIndex forArrowReturnType() {
return { TypeIndexKind::ConArg, I }; return { TypeIndexKind::ArrowReturnType };
} }
static TypeIndex forTupleElement(std::size_t I) { static TypeIndex forTupleElement(std::size_t I) {
return { TypeIndexKind::TupleElement, I }; return { TypeIndexKind::TupleElement, I };
} }
static TypeIndex forAppOpType() {
return { TypeIndexKind::AppOpType };
}
static TypeIndex forAppArgType() {
return { TypeIndexKind::AppArgType };
}
static TypeIndex forTupleIndexType() {
return { TypeIndexKind::TupleIndexType };
}
static TypeIndex forPresentType() {
return { TypeIndexKind::PresentType };
}
}; };
class TypeIterator { class TypeIterator {
@ -116,9 +143,14 @@ namespace bolt {
enum class TypeKind : unsigned char { enum class TypeKind : unsigned char {
Var, Var,
Con, Con,
App,
Arrow, Arrow,
Tuple, Tuple,
TupleIndex, TupleIndex,
Field,
Nil,
Absent,
Present,
}; };
class Type { class Type {
@ -146,8 +178,18 @@ namespace bolt {
return Out; return Out;
} }
/**
* Rewrites the entire substructure of a type to another one.
*
* \param Recursive If true, a successfully rewritten type is itself rewritten
* again, repeating until a terminal type is reached.
*/
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = false);
Type* substitute(const TVSub& Sub); Type* substitute(const TVSub& Sub);
Type* solve();
TypeIterator begin(); TypeIterator begin();
TypeIterator end(); TypeIterator end();
@ -176,11 +218,10 @@ namespace bolt {
public: public:
const size_t Id; const size_t Id;
std::vector<Type*> Args;
ByteString DisplayName; ByteString DisplayName;
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName): inline TCon(const size_t Id, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {} Type(TypeKind::Con), Id(Id), DisplayName(DisplayName) {}
static bool classof(const Type* Ty) { static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con; return Ty->getKind() == TypeKind::Con;
@ -188,12 +229,30 @@ namespace bolt {
}; };
/**
 * Type node of kind TypeKind::App: the application of the type Op to the
 * single argument Arg.
 */
class TApp : public Type {
public:

  Type* Op;  // the type being applied
  Type* Arg; // the argument it is applied to

  TApp(Type* Op, Type* Arg)
    : Type(TypeKind::App),
      Op(Op),
      Arg(Arg) {}

  /// Kind test used for casting: true exactly for nodes of kind App.
  static bool classof(const Type* Ty) {
    return TypeKind::App == Ty->getKind();
  }

};
enum class VarKind { enum class VarKind {
Rigid, Rigid,
Unification, Unification,
}; };
class TVar : public Type { class TVar : public Type {
Type* Parent = this;
public: public:
const size_t Id; const size_t Id;
@ -208,6 +267,10 @@ namespace bolt {
return VK; return VK;
} }
Type* find();
void set(Type* Ty);
static bool classof(const Type* Ty) { static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var; return Ty->getKind() == TypeKind::Var;
} }
@ -272,6 +335,215 @@ namespace bolt {
}; };
/**
 * Type node of kind TypeKind::Nil. It carries no data of its own.
 */
class TNil : public Type {
public:

  TNil()
    : Type(TypeKind::Nil) {}

  /// Kind test used for casting: true exactly for nodes of kind Nil.
  static bool classof(const Type* Ty) {
    return TypeKind::Nil == Ty->getKind();
  }

};
/**
 * Type node of kind TypeKind::Field: one named field followed by the
 * remainder of the row it belongs to.
 */
class TField : public Type {
public:

  ByteString Name; // name of this field
  Type* Ty;        // type attached to this field
  Type* RestTy;    // whatever follows this field

  TField(ByteString Name, Type* Ty, Type* RestTy)
    : Type(TypeKind::Field),
      Name(Name),
      Ty(Ty),
      RestTy(RestTy) {}

  /// Kind test used for casting: true exactly for nodes of kind Field.
  static bool classof(const Type* Ty) {
    return TypeKind::Field == Ty->getKind();
  }

};
/**
 * Type node of kind TypeKind::Absent. It carries no data of its own.
 *
 * NOTE(review): presumably the counterpart of TPresent for a field that is
 * not there; confirm against the unifier.
 */
class TAbsent : public Type {
public:

  TAbsent()
    : Type(TypeKind::Absent) {}

  /// Kind test used for casting: true exactly for nodes of kind Absent.
  static bool classof(const Type* Ty) {
    return TypeKind::Absent == Ty->getKind();
  }

};
/**
 * Type node of kind TypeKind::Present: a wrapper around exactly one other
 * type.
 */
class TPresent : public Type {
public:

  Type* Ty; // the wrapped type

  TPresent(Type* Ty)
    : Type(TypeKind::Present),
      Ty(Ty) {}

  /// Kind test used for casting: true exactly for nodes of kind Present.
  static bool classof(const Type* Ty) {
    return TypeKind::Present == Ty->getKind();
  }

};
template<bool IsConst>
class TypeVisitorBase {
protected:
template<typename T>
using C = std::conditional<IsConst, const T, T>::type;
virtual void enterType(C<Type>* Ty) {}
virtual void exitType(C<Type>* Ty) {}
virtual void visitVarType(C<TVar>* Ty) {
visitEachChild(Ty);
}
virtual void visitAppType(C<TApp>* Ty) {
visitEachChild(Ty);
}
virtual void visitPresentType(C<TPresent>* Ty) {
visitEachChild(Ty);
}
virtual void visitConType(C<TCon>* Ty) {
visitEachChild(Ty);
}
virtual void visitArrowType(C<TArrow>* Ty) {
visitEachChild(Ty);
}
virtual void visitTupleType(C<TTuple>* Ty) {
visitEachChild(Ty);
}
virtual void visitTupleIndexType(C<TTupleIndex>* Ty) {
visitEachChild(Ty);
}
virtual void visitAbsentType(C<TAbsent>* Ty) {
visitEachChild(Ty);
}
virtual void visitFieldType(C<TField>* Ty) {
visitEachChild(Ty);
}
virtual void visitNilType(C<TNil>* Ty) {
visitEachChild(Ty);
}
public:
void visitEachChild(C<Type>* Ty) {
switch (Ty->getKind()) {
case TypeKind::Var:
case TypeKind::Absent:
case TypeKind::Nil:
case TypeKind::Con:
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<C<TArrow>*>(Ty);
for (auto I = 0; I < Arrow->ParamTypes.size(); ++I) {
visit(Arrow->ParamTypes[I]);
}
visit(Arrow->ReturnType);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<C<TTuple>*>(Ty);
for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) {
visit(Tuple->ElementTypes[I]);
}
break;
}
case TypeKind::App:
{
auto App = static_cast<C<TApp>*>(Ty);
visit(App->Op);
visit(App->Arg);
break;
}
case TypeKind::Field:
{
auto Field = static_cast<C<TField>*>(Ty);
visit(Field->Ty);
visit(Field->RestTy);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<C<TPresent>*>(Ty);
visit(Present->Ty);
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<C<TTupleIndex>*>(Ty);
visit(Index->Ty);
break;
}
}
}
void visit(C<Type>* Ty) {
enterType(Ty);
switch (Ty->getKind()) {
case TypeKind::Present:
visitPresentType(static_cast<C<TPresent>*>(Ty));
break;
case TypeKind::Absent:
visitAbsentType(static_cast<C<TAbsent>*>(Ty));
break;
case TypeKind::Nil:
visitNilType(static_cast<C<TNil>*>(Ty));
break;
case TypeKind::Field:
visitFieldType(static_cast<C<TField>*>(Ty));
break;
case TypeKind::Con:
visitConType(static_cast<C<TCon>*>(Ty));
break;
case TypeKind::Arrow:
visitArrowType(static_cast<C<TArrow>*>(Ty));
break;
case TypeKind::Var:
visitVarType(static_cast<C<TVar>*>(Ty));
break;
case TypeKind::Tuple:
visitTupleType(static_cast<C<TTuple>*>(Ty));
break;
case TypeKind::App:
visitAppType(static_cast<C<TApp>*>(Ty));
break;
case TypeKind::TupleIndex:
visitTupleIndexType(static_cast<C<TTupleIndex>*>(Ty));
break;
}
exitType(Ty);
}
virtual ~TypeVisitorBase() {}
};
using TypeVisitor = TypeVisitorBase<false>;
using ConstTypeVisitor = TypeVisitorBase<true>;
// template<typename T> // template<typename T>
// struct DerefHash { // struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept { // std::size_t operator()(const T& Value) const noexcept {

View file

@ -417,6 +417,22 @@ namespace bolt {
return BlockStart; return BlockStart;
} }
// A record expression field starts at its name token.
Token* RecordExpressionField::getFirstToken() const {
return Name;
}
// A record expression field ends wherever its value expression E ends.
Token* RecordExpressionField::getLastToken() const {
return E->getLastToken();
}
// A record expression starts at its LBrace token.
Token* RecordExpression::getFirstToken() const {
return LBrace;
}
// A record expression ends at its RBrace token.
Token* RecordExpression::getLastToken() const {
return RBrace;
}
Token* MemberExpression::getFirstToken() const { Token* MemberExpression::getFirstToken() const {
return E->getFirstToken(); return E->getFirstToken();
} }

View file

@ -3,18 +3,23 @@
// TODO (maybe) make unficiation work like union-find in find() // TODO (maybe) make unficiation work like union-find in find()
// TODO remove Args in TCon and just use it as a constant
// TODO make TApp traversable with TupleIndex
// TODO make simplify() rewrite the types in-place such that a reference too (Bool, Int).0 becomes Bool // TODO make simplify() rewrite the types in-place such that a reference too (Bool, Int).0 becomes Bool
// TODO Fix TVSub to use TVar.Id instead of the pointer address // TODO Add a check for datatypes that create infinite structures.
// TODO Deferred diagnostics // TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include <stack> #include <stack>
#include <map>
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "bolt/Type.hpp"
#include "zen/config.hpp" #include "zen/config.hpp"
#include "zen/range.hpp" #include "zen/range.hpp"
@ -58,11 +63,38 @@ namespace bolt {
} }
} }
/**
 * Produces a simplified version of Ty by rewriting it with
 * Type::rewrite(..., Recursive = true).
 *
 * The rewrite step applied to each visited type does two things:
 *  - a type variable is replaced by the result of TVar::find();
 *  - a tuple index whose (simplified) operand is a tuple is replaced by the
 *    simplified element at that index. When the index is past the end of the
 *    tuple a TupleIndexOutOfRangeDiagnostic is reported and the tuple index
 *    type is left as it is.
 */
Type* Checker::simplifyType(Type* Ty) {

  auto Step = [&](Type* Current) -> Type* {

    if (Current->getKind() == TypeKind::Var) {
      Current = static_cast<TVar*>(Current)->find();
    }

    // Only tuple projections need further work.
    if (Current->getKind() != TypeKind::TupleIndex) {
      return Current;
    }

    auto* Projection = static_cast<TTupleIndex*>(Current);
    auto* Operand = simplifyType(Projection->Ty);

    // The operand is not (yet) known to be a tuple: keep the projection.
    if (Operand->getKind() != TypeKind::Tuple) {
      return Current;
    }

    auto* Tuple = static_cast<TTuple*>(Operand);
    if (Projection->I >= Tuple->ElementTypes.size()) {
      DE.add<TupleIndexOutOfRangeDiagnostic>(Tuple, Projection->I);
      return Current;
    }

    return simplifyType(Tuple->ElementTypes[Projection->I]);
  };

  return Ty->rewrite(Step, /*Recursive=*/true);
}
Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE):
Config(Config), DE(DE) { Config(Config), DE(DE) {
BoolType = new TCon(NextConTypeId++, {}, "Bool"); BoolType = createConType("Bool");
IntType = new TCon(NextConTypeId++, {}, "Int"); IntType = createConType("Int");
StringType = new TCon(NextConTypeId++, {}, "String"); StringType = createConType("String");
} }
Scheme* Checker::lookup(ByteString Name) { Scheme* Checker::lookup(ByteString Name) {
@ -233,6 +265,92 @@ namespace bolt {
// These declarations will be handled separately in check() // These declarations will be handled separately in check()
break; break;
case NodeKind::VariantDeclaration:
{
auto Decl = static_cast<VariantDeclaration*>(X);
auto& ParentCtx = getContext();
auto Ctx = createInferContext();
Contexts.push_back(Ctx);
std::vector<TVar*> Vars;
for (auto TE: Decl->TVs) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Ctx->TVs->emplace(TV);
Vars.push_back(TV);
}
Type* Ty = createConType(Decl->Name->getCanonicalText());
// Must be added early so we can create recursive types
ParentCtx.Env.emplace(Decl->Name->getCanonicalText(), new Forall(Ty));
for (auto Member: Decl->Members) {
switch (Member->getKind()) {
case NodeKind::TupleVariantDeclarationMember:
{
auto TupleMember = static_cast<TupleVariantDeclarationMember*>(Member);
auto RetTy = Ty;
for (auto Var: Vars) {
RetTy = new TApp(RetTy, Var);
}
std::vector<Type*> ParamTypes;
for (auto Element: TupleMember->Elements) {
ParamTypes.push_back(inferTypeExpression(Element));
}
ParentCtx.Env.emplace(TupleMember->Name->getCanonicalText(), new Forall(Ctx->TVs, Ctx->Constraints, new TArrow(ParamTypes, RetTy)));
break;
}
case NodeKind::RecordVariantDeclarationMember:
{
// TODO
break;
}
default:
ZEN_UNREACHABLE
}
}
Contexts.pop_back();
break;
}
case NodeKind::RecordDeclaration:
{
auto Decl = static_cast<RecordDeclaration*>(X);
auto& ParentCtx = getContext();
auto Ctx = createInferContext();
Contexts.push_back(Ctx);
std::vector<TVar*> Vars;
for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Ctx->TVs->emplace(TV);
Vars.push_back(TV);
}
auto Name = Decl->Name->getCanonicalText();
auto Ty = createConType(Name);
// Must be added early so we can create recursive types
ParentCtx.Env.emplace(Name, new Forall(Ty));
// Corresponds to the logic of one branch of a VariantDeclarationMember
Type* FieldsTy = new TNil();
for (auto Field: Decl->Fields) {
FieldsTy = new TField(Field->Name->getCanonicalText(), new TPresent(inferTypeExpression(Field->TypeExpression)), FieldsTy);
}
Type* RetTy = Ty;
for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV);
}
Contexts.pop_back();
addBinding(Name, new Forall(Ctx->TVs, Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
break;
}
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
@ -313,7 +431,7 @@ namespace bolt {
// e.g. Bool, which causes the type assert to also collapse to e.g. // e.g. Bool, which causes the type assert to also collapse to e.g.
// Bool -> Bool -> Bool. // Bool -> Bool -> Bool.
for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) {
addConstraint(new CEqual(Param, TE->getType())); addConstraint(new CEqual(Param, TE->getType(), TE));
} }
} }
@ -338,12 +456,14 @@ namespace bolt {
} }
} }
Type* BindTy;
if (HasContext) { if (HasContext) {
Contexts.pop_back(); Contexts.pop_back();
inferBindings(Let->Pattern, Ty, Let->Ctx->Constraints, Let->Ctx->TVs); BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
} else { } else {
inferBindings(Let->Pattern, Ty); BindTy = inferPattern(Let->Pattern);
} }
addConstraint(new CEqual(BindTy, Ty, Let));
} }
@ -364,9 +484,7 @@ namespace bolt {
for (auto Param: Decl->Params) { for (auto Param: Decl->Params) {
// TODO incorporate Param->TypeAssert or make it a kind of pattern // TODO incorporate Param->TypeAssert or make it a kind of pattern
TVar* TV = createTypeVar(); ParamTypes.push_back(inferPattern(Param->Pattern));
inferBindings(Param->Pattern, TV);
ParamTypes.push_back(TV);
} }
if (Decl->Body) { if (Decl->Body) {
@ -438,6 +556,11 @@ namespace bolt {
break; break;
} }
case NodeKind::VariantDeclaration:
case NodeKind::RecordDeclaration:
// Nothing to do for a type-level declaration
break;
case NodeKind::IfStatement: case NodeKind::IfStatement:
{ {
auto IfStmt = static_cast<IfStatement*>(N); auto IfStmt = static_cast<IfStatement*>(N);
@ -482,6 +605,10 @@ namespace bolt {
} }
// Creates a new type constructor named Name. Its id is the current value of
// NextConTypeId, which is incremented as a side effect, so each call yields a
// different id.
TCon* Checker::createConType(ByteString Name) {
return new TCon(NextConTypeId++, Name);
}
TVarRigid* Checker::createRigidVar(ByteString Name) { TVarRigid* Checker::createRigidVar(ByteString Name) {
auto TV = new TVarRigid(NextTypeVarId++, Name); auto TV = new TVarRigid(NextTypeVarId++, Name);
Contexts.back()->TVs->emplace(TV); Contexts.back()->TVs->emplace(TV);
@ -533,7 +660,7 @@ namespace bolt {
// been solved, with some unification variables being erased. To make // been solved, with some unification variables being erased. To make
// sure we instantiate unification variables that are still in use // sure we instantiate unification variables that are still in use
// we solve before substituting. // we solve before substituting.
return simplify(F->Type)->substitute(Sub); return simplifyType(F->Type)->substitute(Sub);
} }
} }
@ -568,15 +695,28 @@ namespace bolt {
case NodeKind::ReferenceTypeExpression: case NodeKind::ReferenceTypeExpression:
{ {
auto RefTE = static_cast<ReferenceTypeExpression*>(N); auto RefTE = static_cast<ReferenceTypeExpression*>(N);
auto Ty = lookupMono(RefTE->Name->getCanonicalText()); auto Scm = lookup(RefTE->Name->getCanonicalText());
if (Ty == nullptr) { Type* Ty;
if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name); DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
Ty = createTypeVar(); Ty = createTypeVar();
} else {
Ty = instantiate(Scm, RefTE);
} }
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
case NodeKind::AppTypeExpression:
{
auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op);
for (auto Arg: AppTE->Args) {
Ty = new TApp(Ty, inferTypeExpression(Arg));
}
return Ty;
}
case NodeKind::VarTypeExpression: case NodeKind::VarTypeExpression:
{ {
auto VarTE = static_cast<VarTypeExpression*>(N); auto VarTE = static_cast<VarTypeExpression*>(N);
@ -588,8 +728,9 @@ namespace bolt {
Ty = createRigidVar(VarTE->Name->getCanonicalText()); Ty = createRigidVar(VarTE->Name->getCanonicalText());
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
} }
ZEN_ASSERT(Ty->getKind() == TypeKind::Var);
N->setType(Ty); N->setType(Ty);
return Ty; return static_cast<TVar*>(Ty);
} }
case NodeKind::TupleTypeExpression: case NodeKind::TupleTypeExpression:
@ -642,6 +783,19 @@ namespace bolt {
} }
} }
/**
 * Rebuilds the leading chain of TField nodes of Ty in a name-determined order.
 *
 * The chain is walked through RestTy until the first type that is not a
 * TField; that type becomes the tail of the result. The collected fields are
 * then re-created as fresh TField nodes in front of that tail.
 *
 * Fields are prepended one by one while iterating the map in ascending key
 * order, so the field with the greatest name (per ByteString's ordering)
 * ends up at the head of the returned chain. If a name occurs more than
 * once, only its first occurrence is kept, because std::map::emplace does
 * not overwrite an existing key.
 *
 * The original TField nodes are not modified.
 */
Type* sortRow(Type* Ty) {
  std::map<ByteString, TField*> Fields;
  while (Ty->getKind() == TypeKind::Field) {
    auto Field = static_cast<TField*>(Ty);
    Fields.emplace(Field->Name, Field);
    Ty = Field->RestTy;
  }
  // Bind by reference: copying the key on every iteration is wasted work,
  // the TField constructor already takes its own copy.
  for (const auto& [Name, Field]: Fields) {
    Ty = new TField(Name, Field->Ty, Ty);
  }
  return Ty;
}
Type* Checker::inferExpression(Expression* X) { Type* Checker::inferExpression(Expression* X) {
Type* Ty; Type* Ty;
@ -661,9 +815,10 @@ namespace bolt {
for (auto Case: Match->Cases) { for (auto Case: Match->Cases) {
auto NewCtx = createInferContext(); auto NewCtx = createInferContext();
Contexts.push_back(NewCtx); Contexts.push_back(NewCtx);
inferBindings(Case->Pattern, ValTy); auto PattTy = inferPattern(Case->Pattern);
auto ResTy = inferExpression(Case->Expression); addConstraint(new CEqual(PattTy, ValTy, X));
addConstraint(new CEqual(ResTy, Ty, Case->Expression)); auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
Contexts.pop_back(); Contexts.pop_back();
} }
if (!Match->Value) { if (!Match->Value) {
@ -672,6 +827,17 @@ namespace bolt {
break; break;
} }
case NodeKind::RecordExpression:
{
auto Record = static_cast<RecordExpression*>(X);
Ty = new TNil();
for (auto [Field, Comma]: Record->Fields) {
Ty = new TField(Field->Name->getCanonicalText(), new TPresent(inferExpression(Field->getExpression())), Ty);
}
Ty = sortRow(Ty);
break;
}
case NodeKind::ConstantExpression: case NodeKind::ConstantExpression:
{ {
auto Const = static_cast<ConstantExpression*>(X); auto Const = static_cast<ConstantExpression*>(X);
@ -743,16 +909,20 @@ namespace bolt {
case NodeKind::MemberExpression: case NodeKind::MemberExpression:
{ {
auto Member = static_cast<MemberExpression*>(X); auto Member = static_cast<MemberExpression*>(X);
auto ExprTy = inferExpression(Member->E);
switch (Member->Name->getKind()) { switch (Member->Name->getKind()) {
case NodeKind::IntegerLiteral: case NodeKind::IntegerLiteral:
{ {
auto I = static_cast<IntegerLiteral*>(Member->Name); auto I = static_cast<IntegerLiteral*>(Member->Name);
Ty = new TTupleIndex(inferExpression(Member->E), I->getInteger()); Ty = new TTupleIndex(ExprTy, I->getInteger());
break; break;
} }
case NodeKind::Identifier: case NodeKind::Identifier:
{ {
// TODO auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar();
auto RestTy = createTypeVar();
addConstraint(new CEqual(new TField(K->getCanonicalText(), Ty, RestTy), ExprTy, Member));
break; break;
} }
default: default:
@ -778,9 +948,8 @@ namespace bolt {
return Ty; return Ty;
} }
void Checker::inferBindings( Type* Checker::inferPattern(
Pattern* Pattern, Pattern* Pattern,
Type* Type,
ConstraintSet* Constraints, ConstraintSet* Constraints,
TVSet* TVs TVSet* TVs
) { ) {
@ -790,15 +959,39 @@ namespace bolt {
case NodeKind::BindPattern: case NodeKind::BindPattern:
{ {
auto P = static_cast<BindPattern*>(Pattern); auto P = static_cast<BindPattern*>(Pattern);
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Type)); auto Ty = createTypeVar();
break; addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty));
return Ty;
}
case NodeKind::NamedPattern:
{
auto P = static_cast<NamedPattern*>(Pattern);
auto Scm = lookup(P->Name->getCanonicalText());
std::vector<Type*> ParamTypes;
for (auto P2: P->Patterns) {
ParamTypes.push_back(inferPattern(P2, Constraints, TVs));
}
if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar();
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
addConstraint(new CEqual(Ty, new TArrow(ParamTypes, RetTy), P));
return RetTy;
}
case NodeKind::NestedPattern:
{
auto P = static_cast<NestedPattern*>(Pattern);
return inferPattern(P->P, Constraints, TVs);
} }
case NodeKind::LiteralPattern: case NodeKind::LiteralPattern:
{ {
auto P = static_cast<LiteralPattern*>(Pattern); auto P = static_cast<LiteralPattern*>(Pattern);
addConstraint(new CEqual(inferLiteral(P->Literal), Type, P)); return inferLiteral(P->Literal);
break;
} }
default: default:
@ -808,10 +1001,6 @@ namespace bolt {
} }
void Checker::inferBindings(Pattern* Pattern, Type* Type) {
inferBindings(Pattern, Type, new ConstraintSet, new TVSet);
}
Type* Checker::inferLiteral(Literal* L) { Type* Checker::inferLiteral(Literal* L) {
Type* Ty; Type* Ty;
switch (L->getKind()) { switch (L->getKind()) {
@ -927,7 +1116,7 @@ namespace bolt {
// This is ugly but it works. Scan all type variables local to this // This is ugly but it works. Scan all type variables local to this
// declaration and add the classes that they require to Actual. // declaration and add the classes that they require to Actual.
for (auto Ty: *Decl->Ctx->TVs) { for (auto Ty: *Decl->Ctx->TVs) {
auto S = Ty->substitute(C.Solution); auto S = Ty->solve();
if (llvm::isa<TVar>(S)) { if (llvm::isa<TVar>(S)) {
auto TV = static_cast<TVar*>(S); auto TV = static_cast<TVar*>(S);
for (auto Class: TV->Contexts) { for (auto Class: TV->Contexts) {
@ -995,6 +1184,10 @@ namespace bolt {
} }
/// Returns the type attached to Node after passing it through Type::solve().
Type* Checker::getType(TypedNode *Node) {
  auto Recorded = Node->getType();
  return Recorded->solve();
}
void Checker::check(SourceFile *SF) { void Checker::check(SourceFile *SF) {
auto RootContext = createInferContext(); auto RootContext = createInferContext();
Contexts.push_back(RootContext); Contexts.push_back(RootContext);
@ -1042,11 +1235,11 @@ namespace bolt {
} }
infer(SF); infer(SF);
Contexts.pop_back(); Contexts.pop_back();
solve(new CMany(*RootContext->Constraints), Solution); solve(new CMany(*RootContext->Constraints));
checkTypeclassSigs(SF); checkTypeclassSigs(SF);
} }
void Checker::solve(Constraint* Constraint, TVSub& Solution) { void Checker::solve(Constraint* Constraint) {
Queue.push_back(Constraint); Queue.push_back(Constraint);
@ -1094,12 +1287,13 @@ namespace bolt {
if (Con1->Id != Con2-> Id) { if (Con1->Id != Con2-> Id) {
return false; return false;
} }
ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); // TODO must handle a TApp
for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) { // ZEN_ASSERT(Con1->Args.size() == Con2->Args.size());
if (!assignableTo(T1, T2)) { // for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) {
return false; // if (!assignableTo(T1, T2)) {
} // return false;
} // }
// }
return true; return true;
} }
ZEN_UNREACHABLE ZEN_UNREACHABLE
@ -1112,19 +1306,21 @@ namespace bolt {
for (auto Instance: Match->second) { for (auto Instance: Match->second) {
if (assignableTo(Ty, Instance->TypeExps[0]->getType())) { if (assignableTo(Ty, Instance->TypeExps[0]->getType())) {
std::vector<TypeclassContext> S; std::vector<TypeclassContext> S;
for (auto Arg: Ty->Args) { // TODO handle TApp
TypeclassContext Classes; // for (auto Arg: Ty->Args) {
// TODO // TypeclassContext Classes;
S.push_back(Classes); // // TODO
} // S.push_back(Classes);
// }
return S; return S;
} }
} }
} }
DE.add<InstanceNotFoundDiagnostic>(Class, Ty, Source); DE.add<InstanceNotFoundDiagnostic>(Class, Ty, Source);
for (auto Arg: Ty->Args) { // TODO handle TApp
S.push_back({}); // for (auto Arg: Ty->Args) {
} // S.push_back({});
// }
return S; return S;
} }
@ -1145,114 +1341,15 @@ namespace bolt {
void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty) { void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty) {
auto S = findInstanceContext(Ty, Class); auto S = findInstanceContext(Ty, Class);
for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { // TODO handle TApp
propagateClasses(Classes, Arg); // for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) {
} // propagateClasses(Classes, Arg);
// }
}; };
void Checker::solveCEqual(CEqual* C) {
// std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl;
OrigLeft = C->Left;
OrigRight = C->Right;
Source = C->Source;
unify(C->Left, C->Right);
LeftPath = {};
RightPath = {};
}
Type* Checker::find(Type* Ty) {
while (Ty->getKind() == TypeKind::Var) {
auto Match = Solution.find(static_cast<TVar*>(Ty));
if (Match == Solution.end()) {
break;
}
Ty = Match->second;
}
return Ty;
}
Type* Checker::simplify(Type* Ty) {
Ty = find(Ty);
switch (Ty->getKind()) {
case TypeKind::Var:
break;
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(Ty);
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = simplify(Ty);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : Ty;
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(Ty);
bool Changed = false;
std::vector<Type*> NewParamTys;
for (auto ParamTy: Arrow->ParamTypes) {
auto NewParamTy = simplify(ParamTy);
if (NewParamTy != ParamTy) {
Changed = true;
}
NewParamTys.push_back(NewParamTy);
}
auto NewRetTy = simplify(Arrow->ReturnType);
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
Ty = Changed ? new TArrow(NewParamTys, NewRetTy) : Arrow;
break;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(Ty);
bool Changed = false;
std::vector<Type*> NewArgs;
for (auto Arg: Con->Args) {
auto NewArg = simplify(Arg);
if (NewArg != Arg) {
Changed = true;
}
NewArgs.push_back(NewArg);
}
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : Ty;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<TTupleIndex*>(Ty);
auto MaybeTuple = simplify(Index->Ty);
if (llvm::isa<TTuple>(MaybeTuple)) {
auto Tuple = static_cast<TTuple*>(MaybeTuple);
if (Index->I >= Tuple->ElementTypes.size()) {
DE.add<TupleIndexOutOfRangeDiagnostic>(Tuple, Index->I);
} else {
Ty = simplify(Tuple->ElementTypes[Index->I]);
}
}
break;
}
}
return Ty;
}
void Checker::join(TVar* TV, Type* Ty) { void Checker::join(TVar* TV, Type* Ty) {
Solution[TV] = Ty; TV->set(Ty);
propagateClasses(TV->Contexts, Ty); propagateClasses(TV->Contexts, Ty);
@ -1275,16 +1372,6 @@ namespace bolt {
} }
void Checker::unifyError() {
DE.add<UnificationErrorDiagnostic>(
simplify(OrigLeft),
simplify(OrigRight),
LeftPath,
RightPath,
Source
);
}
class ArrowCursor { class ArrowCursor {
std::stack<std::tuple<TArrow*, bool>> Stack; std::stack<std::tuple<TArrow*, bool>> Stack;
@ -1328,10 +1415,115 @@ namespace bolt {
}; };
bool Checker::unify(Type* A, Type* B) { struct Unifier {
A = simplify(A); Checker& C;
B = simplify(B); CEqual* Constraint;
// Internal state used by the unifier
ByteString CurrentFieldName;
TypePath LeftPath;
TypePath RightPath;
Type* getLeft() const {
return Constraint->Left;
}
Type* getRight() const {
return Constraint->Right;
}
Node* getSource() const {
return Constraint->Source;
}
bool unify(Type* A, Type* B);
bool unifyField(Type* A, Type* B);
bool unify() {
return unify(Constraint->Left, Constraint->Right);
}
};
class UnificationFrame {
Unifier& U;
Type* A;
Type* B;
bool DidSwap = false;
public:
UnificationFrame(Unifier& U, Type* A, Type* B):
U(U), A(U.C.simplifyType(A)), B(U.C.simplifyType(B)) {}
/// Reports a UnificationErrorDiagnostic for the constraint currently being
/// solved, passing the simplified left/right types of that constraint together
/// with the paths accumulated so far and the constraint's source node.
void unifyError() {
  auto& Chk = U.C;
  auto Eq = U.Constraint;
  auto FullLeft = Chk.simplifyType(Eq->Left);
  auto FullRight = Chk.simplifyType(Eq->Right);
  Chk.DE.add<UnificationErrorDiagnostic>(FullLeft, FullRight, U.LeftPath, U.RightPath, Eq->Source);
}
/// Appends I to the path that belongs to A. Once swap() has exchanged A and B
/// that is the unifier's RightPath, otherwise its LeftPath.
void pushLeft(TypeIndex I) {
  auto& Target = DidSwap ? U.RightPath : U.LeftPath;
  Target.push_back(I);
}
/// Removes the last index from the path that belongs to A (RightPath when the
/// operands have been swapped, LeftPath otherwise).
void popLeft() {
  auto& Target = DidSwap ? U.RightPath : U.LeftPath;
  Target.pop_back();
}
/// Appends I to the path that belongs to B. Once swap() has exchanged A and B
/// that is the unifier's LeftPath, otherwise its RightPath.
void pushRight(TypeIndex I) {
  auto& Target = DidSwap ? U.LeftPath : U.RightPath;
  Target.push_back(I);
}
/// Removes the last index from the path that belongs to B (LeftPath when the
/// operands have been swapped, RightPath otherwise).
void popRight() {
  auto& Target = DidSwap ? U.LeftPath : U.RightPath;
  Target.pop_back();
}
/// Exchanges the two operands of this frame and flips DidSwap, which the
/// push/pop helpers consult to pick the matching path.
void swap() {
  DidSwap = !DidSwap;
  std::swap(A, B);
}
/// Unifies the two field-presence types held by this frame.
///
/// Returns true when both sides are TAbsent. When exactly one side is TAbsent a
/// FieldNotFoundDiagnostic is reported for U.CurrentFieldName and false is
/// returned. Otherwise the payload types of both sides are unified and that
/// result is returned.
bool unifyField() {
  if (llvm::isa<TAbsent>(A) && llvm::isa<TAbsent>(B)) {
    return true;
  }
  // Normalise so that a lone TAbsent always ends up in A.
  if (llvm::isa<TAbsent>(B)) {
    swap();
  }
  if (llvm::isa<TAbsent>(A)) {
    U.C.DE.add<FieldNotFoundDiagnostic>(U.CurrentFieldName, U.C.simplifyType(U.getLeft()), U.LeftPath, U.getSource());
    return false;
  }
  // NOTE(review): both operands are assumed to be TPresent at this point; the
  // casts below are unchecked - confirm no other kind can reach here.
  auto Present1 = static_cast<TPresent*>(A);
  auto Present2 = static_cast<TPresent*>(B);
  return U.unify(Present1->Ty, Present2->Ty);
}
bool unify() {
if (llvm::isa<TVar>(A) && llvm::isa<TVar>(B)) { if (llvm::isa<TVar>(A) && llvm::isa<TVar>(B)) {
auto Var1 = static_cast<TVar*>(A); auto Var1 = static_cast<TVar*>(A);
@ -1355,11 +1547,15 @@ namespace bolt {
From = Var1; From = Var1;
} }
if (From->Id != To->Id) { if (From->Id != To->Id) {
join(From, To); U.C.join(From, To);
} }
return true; return true;
} }
if (llvm::isa<TVar>(B)) {
swap();
}
if (llvm::isa<TVar>(A)) { if (llvm::isa<TVar>(A)) {
auto TV = static_cast<TVar*>(A); auto TV = static_cast<TVar*>(A);
@ -1380,18 +1576,14 @@ namespace bolt {
return false; return false;
} }
join(TV, B); U.C.join(TV, B);
return true; return true;
} }
if (llvm::isa<TVar>(B)) {
return unify(B, A);
}
if (llvm::isa<TArrow>(A) && llvm::isa<TArrow>(B)) { if (llvm::isa<TArrow>(A) && llvm::isa<TArrow>(B)) {
auto C1 = ArrowCursor(static_cast<TArrow*>(A), LeftPath); auto C1 = ArrowCursor(static_cast<TArrow*>(A), DidSwap ? U.RightPath : U.LeftPath);
auto C2 = ArrowCursor(static_cast<TArrow*>(B), RightPath); auto C2 = ArrowCursor(static_cast<TArrow*>(B), DidSwap ? U.LeftPath : U.RightPath);
bool Success = true; bool Success = true;
for (;;) { for (;;) {
auto T1 = C1.next(); auto T1 = C1.next();
@ -1404,7 +1596,7 @@ namespace bolt {
Success = false; Success = false;
break; break;
} }
if (!unify(T1, T2)) { if (!U.unify(T1, T2)) {
Success = false; Success = false;
} }
} }
@ -1421,15 +1613,29 @@ namespace bolt {
/* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */ /* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */
} }
if (llvm::isa<TArrow>(A)) { if (llvm::isa<TApp>(A) && llvm::isa<TApp>(B)) {
auto Arr = static_cast<TArrow*>(A); auto App1 = static_cast<TApp*>(A);
if (Arr->ParamTypes.empty()) { auto App2 = static_cast<TApp*>(B);
return unify(Arr->ReturnType, B); bool Success = true;
if (!U.unify(App1->Op, App2->Op)) {
Success = false;
} }
if (!U.unify(App1->Arg, App2->Arg)) {
Success = false;
}
return Success;
} }
if (llvm::isa<TArrow>(B)) { if (llvm::isa<TArrow>(B)) {
return unify(B, A); swap();
}
if (llvm::isa<TArrow>(A)) {
auto Arr = static_cast<TArrow*>(A);
if (Arr->ParamTypes.empty()) {
auto Success = U.unify(Arr->ReturnType, B);
return Success;
}
} }
if (llvm::isa<TTuple>(A) && llvm::isa<TTuple>(B)) { if (llvm::isa<TTuple>(A) && llvm::isa<TTuple>(B)) {
@ -1442,19 +1648,21 @@ namespace bolt {
auto Count = Tuple1->ElementTypes.size(); auto Count = Tuple1->ElementTypes.size();
bool Success = true; bool Success = true;
for (size_t I = 0; I < Count; I++) { for (size_t I = 0; I < Count; I++) {
LeftPath.push_back(TypeIndex::forTupleElement(I)); U.LeftPath.push_back(TypeIndex::forTupleElement(I));
RightPath.push_back(TypeIndex::forTupleElement(I)); U.RightPath.push_back(TypeIndex::forTupleElement(I));
if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) { if (!U.unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I])) {
Success = false; Success = false;
} }
LeftPath.pop_back(); U.LeftPath.pop_back();
RightPath.pop_back(); U.RightPath.pop_back();
} }
return Success; return Success;
} }
if (llvm::isa<TTupleIndex>(A) || llvm::isa<TTupleIndex>(B)) { if (llvm::isa<TTupleIndex>(A) || llvm::isa<TTupleIndex>(B)) {
Queue.push_back(C); // Type(s) could not be simplified at the beginning of this function,
// so we have to re-visit the constraint when there is more information.
U.C.Queue.push_back(U.Constraint);
return true; return true;
} }
@ -1471,18 +1679,67 @@ namespace bolt {
unifyError(); unifyError();
return false; return false;
} }
ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); return true;
auto Count = Con1->Args.size(); }
if (llvm::isa<TNil>(A) && llvm::isa<TNil>(B)) {
return true;
}
if (llvm::isa<TField>(A) && llvm::isa<TField>(B)) {
auto Field1 = static_cast<TField*>(A);
auto Field2 = static_cast<TField*>(B);
bool Success = true; bool Success = true;
for (std::size_t I = 0; I < Count; I++) { if (Field1->Name == Field2->Name) {
LeftPath.push_back(TypeIndex::forConArg(I)); U.LeftPath.push_back(TypeIndex::forFieldType());
RightPath.push_back(TypeIndex::forConArg(I)); U.RightPath.push_back(TypeIndex::forFieldType());
if (!unify(Con1->Args[I], Con2->Args[I])) { U.CurrentFieldName = Field1->Name;
if (!U.unifyField(Field1->Ty, Field2->Ty)) {
Success = false; Success = false;
} }
LeftPath.pop_back(); U.LeftPath.pop_back();
RightPath.pop_back(); U.RightPath.pop_back();
U.LeftPath.push_back(TypeIndex::forFieldRest());
U.RightPath.push_back(TypeIndex::forFieldRest());
if (!U.unify(Field1->RestTy, Field2->RestTy)) {
Success = false;
} }
U.LeftPath.pop_back();
U.RightPath.pop_back();
return Success;
}
auto NewRestTy = new TVar(U.C.NextTypeVarId++, VarKind::Unification);
pushLeft(TypeIndex::forFieldRest());
if (!U.unify(Field1->RestTy, new TField(Field2->Name, Field2->Ty, NewRestTy))) {
Success = false;
}
popLeft();
pushRight(TypeIndex::forFieldRest());
if (!U.unify(new TField(Field1->Name, Field1->Ty, NewRestTy), Field2->RestTy)) {
Success = false;
}
popRight();
return Success;
}
if (llvm::isa<TNil>(A) && llvm::isa<TField>(B)) {
swap();
}
if (llvm::isa<TField>(A) && llvm::isa<TNil>(B)) {
auto Field = static_cast<TField*>(A);
bool Success = true;
pushLeft(TypeIndex::forFieldType());
U.CurrentFieldName = Field->Name;
if (!U.unifyField(Field->Ty, new TAbsent)) {
Success = false;
}
popLeft();
pushLeft(TypeIndex::forFieldRest());
if (!U.unify(Field->RestTy, B)) {
Success = false;
}
popLeft();
return Success; return Success;
} }
@ -1490,18 +1747,24 @@ namespace bolt {
return false; return false;
} }
InferContext* Checker::lookupCall(Node* Source, SymbolPath Path) { };
auto Def = Source->getScope()->lookup(Path);
auto Match = CallGraph.find(Def); bool Unifier::unify(Type* A, Type* B) {
if (Match == CallGraph.end()) { UnificationFrame Frame { *this, A, B };
return nullptr; return Frame.unify();
}
return Match->second;
} }
Type* Checker::getType(TypedNode *Node) { bool Unifier::unifyField(Type* A, Type* B) {
return Node->getType()->substitute(Solution); UnificationFrame Frame { *this, A, B };
return Frame.unifyField();
} }
/// Solves one equality constraint by running a fresh Unifier over it. The
/// boolean result of the unification is not used here.
void Checker::solveCEqual(CEqual* C) {
  // Debugging aid:
  // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl;
  Unifier Solver { *this, C };
  Solver.unify();
}
} }

View file

@ -13,6 +13,7 @@
#define ANSI_RESET "\u001b[0m" #define ANSI_RESET "\u001b[0m"
#define ANSI_BOLD "\u001b[1m" #define ANSI_BOLD "\u001b[1m"
#define ANSI_ITALIC "\u001b[3m"
#define ANSI_UNDERLINE "\u001b[4m" #define ANSI_UNDERLINE "\u001b[4m"
#define ANSI_REVERSED "\u001b[7m" #define ANSI_REVERSED "\u001b[7m"
@ -107,6 +108,16 @@ namespace bolt {
return "'return'"; return "'return'";
case NodeKind::TypeKeyword: case NodeKind::TypeKeyword:
return "'type'"; return "'type'";
case NodeKind::LetDeclaration:
return "a let-declaration";
case NodeKind::CallExpression:
return "a call-expression";
case NodeKind::InfixExpression:
return "an infix-expression";
case NodeKind::ReferenceExpression:
return "a function or variable reference";
case NodeKind::MatchExpression:
return "a match-expression";
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
@ -151,16 +162,12 @@ namespace bolt {
case TypeKind::Con: case TypeKind::Con:
{ {
auto Y = static_cast<const TCon*>(Ty); auto Y = static_cast<const TCon*>(Ty);
std::ostringstream Out; return Y->DisplayName;
if (!Y->DisplayName.empty()) {
Out << Y->DisplayName;
} else {
Out << "C" << Y->Id;
} }
for (auto Arg: Y->Args) { case TypeKind::App:
Out << " " << describe(Arg); {
} auto Y = static_cast<const TApp*>(Ty);
return Out.str(); return describe(Y->Op) + " " + describe(Y->Arg);
} }
case TypeKind::Tuple: case TypeKind::Tuple:
{ {
@ -182,20 +189,36 @@ namespace bolt {
auto Y = static_cast<const TTupleIndex*>(Ty); auto Y = static_cast<const TTupleIndex*>(Ty);
return describe(Y->Ty) + "." + std::to_string(Y->I); return describe(Y->Ty) + "." + std::to_string(Y->I);
} }
case TypeKind::Nil:
return "{}";
case TypeKind::Absent:
return "Abs";
case TypeKind::Present:
{
auto Y = static_cast<const TPresent*>(Ty);
return describe(Y->Ty);
}
case TypeKind::Field:
{
auto Y = static_cast<const TField*>(Ty);
std::ostringstream out;
out << "{ " << Y->Name << ": " << describe(Y->Ty);
Ty = Y->RestTy;
while (Ty->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty);
out << "; " + Y->Name + ": " + describe(Y->Ty);
Ty = Y->RestTy;
}
if (Ty->getKind() != TypeKind::Nil) {
out << "; " + describe(Ty);
}
out << " }";
return out.str();
}
} }
} }
DiagnosticStore::~DiagnosticStore() { void writeForegroundANSI(Color C, std::ostream& Out) {
for (auto D: Diagnostics) {
delete D;
}
}
ConsoleDiagnostics::ConsoleDiagnostics(std::ostream& Out):
Out(Out) {}
void ConsoleDiagnostics::setForegroundColor(Color C) {
if (EnableColors) {
switch (C) { switch (C) {
case Color::None: case Color::None:
break; break;
@ -225,11 +248,8 @@ namespace bolt {
break; break;
} }
} }
}
void writeBackgroundANSI(Color C, std::ostream& Out) {
void ConsoleDiagnostics::setBackgroundColor(Color C) {
if (EnableColors) {
switch (C) { switch (C) {
case Color::None: case Color::None:
break; break;
@ -259,27 +279,95 @@ namespace bolt {
break; break;
} }
} }
/// Deletes every diagnostic held in Diagnostics.
DiagnosticStore::~DiagnosticStore() {
  for (auto* Owned: Diagnostics) {
    delete Owned;
  }
}
// Creates a console reporter that keeps Out as the stream its style and write
// methods send their output to.
ConsoleDiagnostics::ConsoleDiagnostics(std::ostream& Out):
Out(Out) {}
/// Records C as the active foreground colour and, when colours are enabled,
/// writes the matching escape sequence to the output stream.
void ConsoleDiagnostics::setForegroundColor(Color C) {
  ActiveStyle.setForegroundColor(C);
  if (EnableColors) {
    writeForegroundANSI(C, Out);
  }
}
/// Records C as the active background colour and, when colours are enabled,
/// writes the matching escape sequence to the output stream.
void ConsoleDiagnostics::setBackgroundColor(Color C) {
  ActiveStyle.setBackgroundColor(C);
  if (EnableColors) {
    const bool Clearing = (C == Color::None);
    if (Clearing) {
      // Clearing the background is done with a full reset, so the styles that
      // are still active have to be written out again afterwards.
      Out << ANSI_RESET;
      applyStyles();
    }
    writeBackgroundANSI(C, Out);
  }
}
/// Re-emits everything recorded in ActiveStyle: bold, underline and italic
/// first, then the background colour, then the foreground colour.
void ConsoleDiagnostics::applyStyles() {
  auto& Current = ActiveStyle;
  if (Current.isBold()) Out << ANSI_BOLD;
  if (Current.isUnderline()) Out << ANSI_UNDERLINE;
  if (Current.isItalic()) Out << ANSI_ITALIC;
  // The colour setters perform their own EnableColors check.
  if (Current.hasBackgroundColor()) setBackgroundColor(Current.getBackgroundColor());
  if (Current.hasForegroundColor()) setForegroundColor(Current.getForegroundColor());
}
void ConsoleDiagnostics::setBold(bool Enable) { void ConsoleDiagnostics::setBold(bool Enable) {
ActiveStyle.setBold(Enable);
if (!EnableColors) {
return;
}
if (Enable) { if (Enable) {
Out << ANSI_BOLD; Out << ANSI_BOLD;
} else {
Out << ANSI_RESET;
applyStyles();
} }
} }
void ConsoleDiagnostics::setItalic(bool Enable) { void ConsoleDiagnostics::setItalic(bool Enable) {
ActiveStyle.setItalic(Enable);
if (!EnableColors) {
return;
}
if (Enable) { if (Enable) {
// TODO Out << ANSI_ITALIC;
} else {
Out << ANSI_RESET;
applyStyles();
} }
} }
void ConsoleDiagnostics::setUnderline(bool Enable) { void ConsoleDiagnostics::setUnderline(bool Enable) {
ActiveStyle.setItalic(Enable);
if (!EnableColors) {
return;
}
if (Enable) { if (Enable) {
Out << ANSI_UNDERLINE; Out << ANSI_UNDERLINE;
} else {
Out << ANSI_RESET;
applyStyles();
} }
} }
void ConsoleDiagnostics::resetStyles() { void ConsoleDiagnostics::resetStyles() {
ActiveStyle.reset();
if (EnableColors) { if (EnableColors) {
Out << ANSI_RESET; Out << ANSI_RESET;
} }
@ -391,8 +479,159 @@ namespace bolt {
} }
void ConsoleDiagnostics::writeType(const Type* Ty) { void ConsoleDiagnostics::writeType(const Type* Ty) {
TypePath Path;
writeType(Ty, Path);
}
void ConsoleDiagnostics::writeType(const Type* Ty, const TypePath& Underline) {
setForegroundColor(Color::Green); setForegroundColor(Color::Green);
write(describe(Ty));
class TypePrinter : public ConstTypeVisitor {
TypePath Path;
ConsoleDiagnostics& W;
const TypePath& Underline;
public:
TypePrinter(ConsoleDiagnostics& W, const TypePath& Underline):
W(W), Underline(Underline) {}
bool shouldUnderline() const {
return !Underline.empty() && Path == Underline;
}
void enterType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(true);
}
}
void exitType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(false);
}
}
void visitAppType(const TApp *Ty) override {
auto Y = static_cast<const TApp*>(Ty);
Path.push_back(TypeIndex::forAppOpType());
visit(Y->Op);
Path.pop_back();
W.write(" ");
Path.push_back(TypeIndex::forAppArgType());
visit(Y->Arg);
Path.pop_back();
}
void visitVarType(const TVar* Ty) override {
if (Ty->getVarKind() == VarKind::Rigid) {
W.write(static_cast<const TVarRigid*>(Ty)->Name);
return;
}
W.write("a");
W.write(Ty->Id);
}
void visitConType(const TCon *Ty) override {
W.write(Ty->DisplayName);
}
void visitArrowType(const TArrow* Ty) override {
W.write("(");
bool First = true;
std::size_t I = 0;
for (auto PT: Ty->ParamTypes) {
if (First) First = false;
else W.write(", ");
Path.push_back(TypeIndex::forArrowParamType(I++));
visit(PT);
Path.pop_back();
}
W.write(") -> ");
Path.push_back(TypeIndex::forArrowReturnType());
visit(Ty->ReturnType);
Path.pop_back();
}
void visitTupleType(const TTuple *Ty) override {
W.write("(");
if (Ty->ElementTypes.size()) {
auto Iter = Ty->ElementTypes.begin();
Path.push_back(TypeIndex::forTupleElement(0));
visit(*Iter++);
Path.pop_back();
std::size_t I = 1;
while (Iter != Ty->ElementTypes.end()) {
W.write(", ");
Path.push_back(TypeIndex::forTupleElement(I++));
visit(*Iter++);
Path.pop_back();
}
}
W.write(")");
}
void visitTupleIndexType(const TTupleIndex *Ty) override {
Path.push_back(TypeIndex::forTupleIndexType());
visit(Ty->Ty);
Path.pop_back();
W.write(".");
W.write(Ty->I);
}
void visitNilType(const TNil *Ty) override {
W.write("{}");
}
void visitAbsentType(const TAbsent *Ty) override {
W.write("Abs");
}
void visitPresentType(const TPresent *Ty) override {
Path.push_back(TypeIndex::forPresentType());
visit(Ty->Ty);
Path.pop_back();
}
void visitFieldType(const TField* Ty) override {
W.write("{ ");
W.write(Ty->Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Ty->Ty);
Path.pop_back();
auto Ty2 = Ty->RestTy;
Path.push_back(TypeIndex::forFieldRest());
std::size_t I = 1;
while (Ty2->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty2);
W.write("; ");
W.write(Y->Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Y->Ty);
Path.pop_back();
Ty2 = Y->RestTy;
Path.push_back(TypeIndex::forFieldRest());
++I;
}
if (Ty2->getKind() != TypeKind::Nil) {
W.write("; ");
visit(Ty);
}
W.write(" }");
for (auto K = 0; K < I; K++) {
Path.pop_back();
}
}
};
TypePrinter P { *this, Underline };
P.visit(Ty);
resetStyles(); resetStyles();
} }
@ -533,40 +772,51 @@ namespace bolt {
case DiagnosticKind::UnificationError: case DiagnosticKind::UnificationError:
{ {
auto E = static_cast<const UnificationErrorDiagnostic&>(D); auto E = static_cast<const UnificationErrorDiagnostic&>(D);
auto Left = E.OrigLeft->resolve(E.LeftPath);
auto Right = E.OrigRight->resolve(E.RightPath);
writePrefix(E); writePrefix(E);
auto Left = E.Left->resolve(E.LeftPath);
auto Right = E.Right->resolve(E.RightPath);
write("the types "); write("the types ");
writeType(Left); writeType(Left);
write(" and "); write(" and ");
writeType(Right); writeType(Right);
write(" failed to match\n\n"); write(" failed to match\n\n");
if (E.Source) { setForegroundColor(Color::Yellow);
setBold(true);
write(" info: ");
resetStyles();
write("due to an equality constraint on ");
write(describe(E.Source->getKind()));
write(":\n\n");
write(" - left type ");
writeType(E.OrigLeft, E.LeftPath);
write("\n");
write(" - right type ");
writeType(E.OrigRight, E.RightPath);
write("\n\n");
writeNode(E.Source); writeNode(E.Source);
Out << "\n"; write("\n");
} // if (E.Left != E.OrigLeft) {
if (!E.LeftPath.empty()) { // setForegroundColor(Color::Yellow);
setForegroundColor(Color::Yellow); // setBold(true);
setBold(true); // write(" info: ");
write(" info: "); // resetStyles();
resetStyles(); // write("the type ");
write("the type "); // writeType(E.Left);
writeType(Left); // write(" occurs in the full type ");
write(" occurs in the full type "); // writeType(E.OrigLeft);
writeType(E.Left); // write("\n\n");
write("\n\n"); // }
} // if (E.Right != E.OrigRight) {
if (!E.RightPath.empty()) { // setForegroundColor(Color::Yellow);
setForegroundColor(Color::Yellow); // setBold(true);
setBold(true); // write(" info: ");
write(" info: "); // resetStyles();
resetStyles(); // write("the type ");
write("the type "); // writeType(E.Right);
writeType(Right); // write(" occurs in the full type ");
write(" occurs in the full type "); // writeType(E.OrigRight);
writeType(E.Right); // write("\n\n");
write("\n\n"); // }
}
break; break;
} }
@ -634,6 +884,18 @@ namespace bolt {
break; break;
} }
case DiagnosticKind::FieldNotFound:
{
auto E = static_cast<const FieldNotFoundDiagnostic&>(D);
writePrefix(E);
write("the field '");
write(E.Name);
write("' was required in one type but not found in another\n\n");
writeNode(E.Source);
write("\n");
break;
}
} }
} }

View file

@ -473,6 +473,75 @@ after_tuple_element:
return new MatchExpression(static_cast<MatchKeyword*>(T0), Value, BlockStart, Cases); return new MatchExpression(static_cast<MatchKeyword*>(T0), Value, BlockStart, Cases);
} }
/// Parses a record expression: `{` followed by either `}` or one or more
/// `name = expression` fields separated by commas, ended by `}`.
///
/// Returns the new RecordExpression, or nullptr when an expected token or
/// expression is missing. On failure every node collected so far is unref'ed.
RecordExpression* Parser::parseRecordExpression() {
  auto Open = expectToken<class LBrace>();
  if (!Open) {
    return nullptr;
  }

  std::vector<std::tuple<RecordExpressionField*, Comma*>> Parsed;
  class RBrace* Close = nullptr;

  // Releases the opening brace and all completed fields with their commas.
  auto DropParsed = [&]() {
    Open->unref();
    for (auto [Field, Separator]: Parsed) {
      Field->unref();
      Separator->unref();
    }
  };

  auto First = Tokens.peek();
  bool Closed = First->getKind() == NodeKind::RBrace;
  if (Closed) {
    // Empty record: `{}`
    Tokens.get();
    Close = static_cast<class RBrace*>(First);
  }

  while (!Closed) {

    auto Key = expectToken<Identifier>();
    if (!Key) {
      DropParsed();
      return nullptr;
    }

    auto Assign = expectToken<class Equals>();
    if (!Assign) {
      DropParsed();
      Key->unref();
      return nullptr;
    }

    auto Value = parseExpression();
    if (!Value) {
      DropParsed();
      Key->unref();
      Assign->unref();
      return nullptr;
    }

    auto Next = Tokens.peek();
    const auto NextKind = Next->getKind();
    if (NextKind != NodeKind::Comma && NextKind != NodeKind::RBrace) {
      DE.add<UnexpectedTokenDiagnostic>(File, Next, std::vector { NodeKind::Comma, NodeKind::RBrace });
      DropParsed();
      Key->unref();
      Assign->unref();
      Value->unref();
      return nullptr;
    }

    Tokens.get();
    auto Field = new RecordExpressionField { Key, Assign, Value };
    if (NextKind == NodeKind::Comma) {
      Parsed.push_back(std::make_tuple(Field, static_cast<Comma*>(Next)));
    } else {
      // The last field carries no trailing comma.
      Close = static_cast<class RBrace*>(Next);
      Parsed.push_back(std::make_tuple(Field, nullptr));
      Closed = true;
    }

  }

  return new RecordExpression { Open, Parsed, Close };
}
Expression* Parser::parsePrimitiveExpression() { Expression* Parser::parsePrimitiveExpression() {
auto T0 = Tokens.peek(); auto T0 = Tokens.peek();
switch (T0->getKind()) { switch (T0->getKind()) {
@ -562,9 +631,11 @@ after_tuple_elements:
case NodeKind::StringLiteral: case NodeKind::StringLiteral:
Tokens.get(); Tokens.get();
return new ConstantExpression(static_cast<Literal*>(T0)); return new ConstantExpression(static_cast<Literal*>(T0));
case NodeKind::LBrace:
return parseRecordExpression();
default: default:
// Tokens.get(); // Tokens.get();
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::IntegerLiteral, NodeKind::StringLiteral }); DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::LBrace, NodeKind::IntegerLiteral, NodeKind::StringLiteral });
return nullptr; return nullptr;
} }
} }
@ -603,7 +674,12 @@ finish:
std::vector<Expression*> Args; std::vector<Expression*> Args;
for (;;) { for (;;) {
auto T1 = Tokens.peek(); auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || T1->getKind() == NodeKind::Comma || ExprOperators.isInfix(T1)) { if (T1->getKind() == NodeKind::LineFoldEnd
|| T1->getKind() == NodeKind::RParen
|| T1->getKind() == NodeKind::RBrace
|| T1->getKind() == NodeKind::BlockStart
|| T1->getKind() == NodeKind::Comma
|| ExprOperators.isInfix(T1)) {
break; break;
} }
auto Arg = parsePrimitiveExpression(); auto Arg = parsePrimitiveExpression();

View file

@ -28,7 +28,6 @@ namespace bolt {
return false; return false;
} }
switch (Kind) { switch (Kind) {
case TypeIndexKind::ConArg:
case TypeIndexKind::ArrowParamType: case TypeIndexKind::ArrowParamType:
case TypeIndexKind::TupleElement: case TypeIndexKind::TupleElement:
return I == Other.I; return I == Other.I;
@ -41,6 +40,9 @@ namespace bolt {
switch (Kind) { switch (Kind) {
case TypeIndexKind::End: case TypeIndexKind::End:
break; break;
case TypeIndexKind::AppOpType:
Kind = TypeIndexKind::AppArgType;
break;
case TypeIndexKind::ArrowParamType: case TypeIndexKind::ArrowParamType:
{ {
auto Arrow = llvm::cast<TArrow>(Ty); auto Arrow = llvm::cast<TArrow>(Ty);
@ -51,19 +53,16 @@ namespace bolt {
} }
break; break;
} }
case TypeIndexKind::FieldType:
Kind = TypeIndexKind::FieldRestType;
break;
case TypeIndexKind::FieldRestType:
case TypeIndexKind::TupleIndexType:
case TypeIndexKind::PresentType:
case TypeIndexKind::AppArgType:
case TypeIndexKind::ArrowReturnType: case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End; Kind = TypeIndexKind::End;
break; break;
case TypeIndexKind::ConArg:
{
auto Con = llvm::cast<TCon>(Ty);
if (I+1 < Con->Args.size()) {
++I;
} else {
Kind = TypeIndexKind::End;
}
break;
}
case TypeIndexKind::TupleElement: case TypeIndexKind::TupleElement:
{ {
auto Tuple = llvm::cast<TTuple>(Ty); auto Tuple = llvm::cast<TTuple>(Ty);
@ -77,6 +76,95 @@ namespace bolt {
} }
} }
/// Rebuilds this type bottom-up through Fn.
///
/// Fn is applied to this node first. If Fn returns a different node and
/// Recursive is false, that node is returned untouched. Otherwise the children
/// of the resulting node are rewritten in turn, and a new node is allocated
/// only when at least one child changed.
///
/// @param Fn        maps a node to its replacement, or returns the node itself
/// @param Recursive whether to keep rewriting inside a node produced by Fn
/// @return the rewritten type; the same pointer when nothing changed
Type* Type::rewrite(std::function<Type*(Type*)> Fn, bool Recursive) {
  auto Ty2 = Fn(this);
  if (!Recursive && this != Ty2) {
    return Ty2;
  }
  // Dispatch on the node that is actually being traversed. Fn may have
  // returned a node of another kind than this one, in which case the casts
  // below would be applied to the wrong class.
  switch (Ty2->getKind()) {
    case TypeKind::Var:
      return Ty2;
    case TypeKind::Arrow:
    {
      auto Arrow = static_cast<TArrow*>(Ty2);
      bool Changed = false;
      std::vector<Type*> NewParamTypes;
      for (auto Ty: Arrow->ParamTypes) {
        auto NewParamType = Ty->rewrite(Fn, Recursive);
        if (NewParamType != Ty) {
          Changed = true;
        }
        NewParamTypes.push_back(NewParamType);
      }
      auto NewRetTy = Arrow->ReturnType->rewrite(Fn, Recursive);
      if (NewRetTy != Arrow->ReturnType) {
        Changed = true;
      }
      return Changed ? new TArrow(NewParamTypes, NewRetTy) : Ty2;
    }
    case TypeKind::Con:
      return Ty2;
    case TypeKind::App:
    {
      auto App = static_cast<TApp*>(Ty2);
      auto NewOp = App->Op->rewrite(Fn, Recursive);
      auto NewArg = App->Arg->rewrite(Fn, Recursive);
      if (NewOp == App->Op && NewArg == App->Arg) {
        return App;
      }
      return new TApp(NewOp, NewArg);
    }
    case TypeKind::TupleIndex:
    {
      auto Tuple = static_cast<TTupleIndex*>(Ty2);
      auto NewTy = Tuple->Ty->rewrite(Fn, Recursive);
      return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
    }
    case TypeKind::Tuple:
    {
      auto Tuple = static_cast<TTuple*>(Ty2);
      bool Changed = false;
      std::vector<Type*> NewElementTypes;
      for (auto Ty: Tuple->ElementTypes) {
        auto NewElementType = Ty->rewrite(Fn, Recursive);
        if (NewElementType != Ty) {
          Changed = true;
        }
        NewElementTypes.push_back(NewElementType);
      }
      return Changed ? new TTuple(NewElementTypes) : Ty2;
    }
    case TypeKind::Nil:
      return Ty2;
    case TypeKind::Absent:
      return Ty2;
    case TypeKind::Field:
    {
      auto Field = static_cast<TField*>(Ty2);
      bool Changed = false;
      auto NewTy = Field->Ty->rewrite(Fn, Recursive);
      if (NewTy != Field->Ty) {
        Changed = true;
      }
      auto NewRestTy = Field->RestTy->rewrite(Fn, Recursive);
      if (NewRestTy != Field->RestTy) {
        Changed = true;
      }
      return Changed ? new TField(Field->Name, NewTy, NewRestTy) : Ty2;
    }
    case TypeKind::Present:
    {
      auto Present = static_cast<TPresent*>(Ty2);
      auto NewTy = Present->Ty->rewrite(Fn, Recursive);
      if (NewTy == Present->Ty) {
        return Ty2;
      }
      return new TPresent(NewTy);
    }
  }
  // Not reached for the kinds handled above; avoids flowing off the end of a
  // value-returning function.
  return Ty2;
}
void Type::addTypeVars(TVSet& TVs) { void Type::addTypeVars(TVSet& TVs) {
switch (Kind) { switch (Kind) {
case TypeKind::Var: case TypeKind::Var:
@ -92,11 +180,12 @@ namespace bolt {
break; break;
} }
case TypeKind::Con: case TypeKind::Con:
break;
case TypeKind::App:
{ {
auto Con = static_cast<TCon*>(this); auto App = static_cast<TApp*>(this);
for (auto Ty: Con->Args) { App->Op->addTypeVars(TVs);
Ty->addTypeVars(TVs); App->Arg->addTypeVars(TVs);
}
break; break;
} }
case TypeKind::TupleIndex: case TypeKind::TupleIndex:
@ -113,6 +202,23 @@ namespace bolt {
} }
break; break;
} }
case TypeKind::Nil:
break;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
Field->Ty->addTypeVars(TVs);
Field->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
Present->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Absent:
break;
} }
} }
@ -131,14 +237,11 @@ namespace bolt {
return Arrow->ReturnType->hasTypeVar(TV); return Arrow->ReturnType->hasTypeVar(TV);
} }
case TypeKind::Con: case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false; return false;
case TypeKind::App:
{
  // The variable occurs in an application when it occurs in the operator
  // or in the argument. Requiring both (&&) reported variables that appear
  // on only one side as absent.
  auto App = static_cast<TApp*>(this);
  return App->Op->hasTypeVar(TV) || App->Arg->hasTypeVar(TV);
}
case TypeKind::TupleIndex: case TypeKind::TupleIndex:
{ {
@ -155,173 +258,181 @@ namespace bolt {
} }
return false; return false;
} }
case TypeKind::Nil:
return false;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
return Field->Ty->hasTypeVar(TV) || Field->RestTy->hasTypeVar(TV);
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
return Present->Ty->hasTypeVar(TV);
}
case TypeKind::Absent:
return false;
} }
} }
// Produces this type with every type variable handed to the callback
// replaced by whatever TVar::find() returns for it. All other types are
// passed back to rewrite() unchanged.
Type* Type::solve() {
  auto ResolveVar = [](auto Candidate) {
    const bool IsVar = Candidate->getKind() == TypeKind::Var;
    return IsVar ? static_cast<TVar*>(Candidate)->find() : Candidate;
  };
  return rewrite(ResolveVar);
}
// Applies the substitution Sub to this type.
//
// The traversal is delegated to rewrite(). For every type variable that has
// an entry in Sub, the mapped type is substituted again with the same Sub,
// so entries that themselves mention substituted variables are expanded too.
// Variables without an entry, and all non-variable types, are returned to
// rewrite() as they are.
Type* Type::substitute(const TVSub &Sub) {
  return rewrite([&](auto Ty) {
    if (llvm::isa<TVar>(Ty)) {
      auto TV = static_cast<TVar*>(Ty);
      auto Match = Sub.find(TV);
      return Match != Sub.end() ? Match->second->substitute(Sub) : Ty;
    }
    return Ty;
  });
}
// Returns the direct child of this type that Index selects.
//
// Each index kind is paired with one concrete type class; the llvm::cast
// calls assume that `this` really is of that class, so the caller has to
// pass an index that was produced for this kind of type. Index.I is only
// consulted for the kinds that address a list (tuple elements and arrow
// parameters). TypeIndexKind::End does not denote a child and is treated as
// unreachable.
Type* Type::resolve(const TypeIndex& Index) const noexcept {
  switch (Index.Kind) {
    case TypeIndexKind::PresentType:
      return llvm::cast<TPresent>(this)->Ty;
    case TypeIndexKind::AppOpType:
      return llvm::cast<TApp>(this)->Op;
    case TypeIndexKind::AppArgType:
      return llvm::cast<TApp>(this)->Arg;
    case TypeIndexKind::TupleIndexType:
      return llvm::cast<TTupleIndex>(this)->Ty;
    case TypeIndexKind::TupleElement:
      return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
    case TypeIndexKind::ArrowParamType:
      return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
    case TypeIndexKind::ArrowReturnType:
      return llvm::cast<TArrow>(this)->ReturnType;
    case TypeIndexKind::FieldType:
      return llvm::cast<TField>(this)->Ty;
    case TypeIndexKind::FieldRestType:
      return llvm::cast<TField>(this)->RestTy;
    case TypeIndexKind::End:
      ZEN_UNREACHABLE
  }
  ZEN_UNREACHABLE
}
bool Type::operator==(const Type& Other) const noexcept { // bool Type::operator==(const Type& Other) const noexcept {
switch (Kind) { // switch (Kind) {
case TypeKind::Var: // case TypeKind::Var:
if (Other.Kind != TypeKind::Var) { // if (Other.Kind != TypeKind::Var) {
return false; // return false;
} // }
return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id; // return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
case TypeKind::Tuple: // case TypeKind::Tuple:
{ // {
if (Other.Kind != TypeKind::Tuple) { // if (Other.Kind != TypeKind::Tuple) {
return false; // return false;
} // }
auto A = static_cast<const TTuple&>(*this); // auto A = static_cast<const TTuple&>(*this);
auto B = static_cast<const TTuple&>(Other); // auto B = static_cast<const TTuple&>(Other);
if (A.ElementTypes.size() != B.ElementTypes.size()) { // if (A.ElementTypes.size() != B.ElementTypes.size()) {
return false; // return false;
} // }
for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) { // for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
if (*T1 != *T2) { // if (*T1 != *T2) {
return false; // return false;
} // }
} // }
return true; // return true;
} // }
case TypeKind::TupleIndex: // case TypeKind::TupleIndex:
{ // {
if (Other.Kind != TypeKind::TupleIndex) { // if (Other.Kind != TypeKind::TupleIndex) {
return false; // return false;
} // }
auto A = static_cast<const TTupleIndex&>(*this); // auto A = static_cast<const TTupleIndex&>(*this);
auto B = static_cast<const TTupleIndex&>(Other); // auto B = static_cast<const TTupleIndex&>(Other);
return A.I == B.I && *A.Ty == *B.Ty; // return A.I == B.I && *A.Ty == *B.Ty;
} // }
case TypeKind::Con: // case TypeKind::Con:
{ // {
if (Other.Kind != TypeKind::Con) { // if (Other.Kind != TypeKind::Con) {
return false; // return false;
} // }
auto A = static_cast<const TCon&>(*this); // auto A = static_cast<const TCon&>(*this);
auto B = static_cast<const TCon&>(Other); // auto B = static_cast<const TCon&>(Other);
if (A.Id != B.Id) { // if (A.Id != B.Id) {
return false; // return false;
} // }
if (A.Args.size() != B.Args.size()) { // if (A.Args.size() != B.Args.size()) {
return false; // return false;
} // }
for (auto [T1, T2]: zen::zip(A.Args, B.Args)) { // for (auto [T1, T2]: zen::zip(A.Args, B.Args)) {
if (*T1 != *T2) { // if (*T1 != *T2) {
return false; // return false;
} // }
} // }
return true; // return true;
} // }
case TypeKind::Arrow: // case TypeKind::Arrow:
{ // {
// FIXME Do we really need to 'curry' this type? // if (Other.Kind != TypeKind::Arrow) {
if (Other.Kind != TypeKind::Arrow) { // return false;
return false; // }
} // auto A = static_cast<const TArrow&>(*this);
auto A = static_cast<const TArrow&>(*this); // auto B = static_cast<const TArrow&>(Other);
auto B = static_cast<const TArrow&>(Other); // /* ArrowCursor C1 { &A }; */
/* ArrowCursor C1 { &A }; */ // /* ArrowCursor C2 { &B }; */
/* ArrowCursor C2 { &B }; */ // /* for (;;) { */
/* for (;;) { */ // /* auto T1 = C1.next(); */
/* auto T1 = C1.next(); */ // /* auto T2 = C2.next(); */
/* auto T2 = C2.next(); */ // /* if (T1 == nullptr && T2 == nullptr) { */
/* if (T1 == nullptr && T2 == nullptr) { */ // /* break; */
/* break; */ // /* } */
/* } */ // /* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */
/* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */ // /* return false; */
/* return false; */ // /* } */
/* } */ // /* } */
/* } */ // if (A.ParamTypes.size() != B.ParamTypes.size()) {
if (A.ParamTypes.size() != B.ParamTypes.size()) { // return false;
return false; // }
} // for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) {
for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) { // if (*T1 != *T2) {
if (*T1 != *T2) { // return false;
return false; // }
} // }
} // return A.ReturnType != B.ReturnType;
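//       return *A.ReturnType == *B.ReturnType; // FIXME(review): previously `A.ReturnType != B.ReturnType`, an inverted comparison of the pointers rather than the pointed-to types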
} // case TypeKind::Absent:
} // if (Other.Kind != TypeKind::Absent) {
} // return false;
// }
// return true;
// case TypeKind::Nil:
// if (Other.Kind != TypeKind::Nil) {
// return false;
// }
// return true;
// case TypeKind::Present:
// {
// if (Other.Kind != TypeKind::Present) {
// return false;
// }
// auto A = static_cast<const TPresent&>(*this);
// auto B = static_cast<const TPresent&>(Other);
// return *A.Ty == *B.Ty;
// }
// case TypeKind::Field:
// {
// if (Other.Kind != TypeKind::Field) {
// return false;
// }
// auto A = static_cast<const TField&>(*this);
// auto B = static_cast<const TField&>(Other);
// return *A.Ty == *B.Ty && *A.RestTy == *B.RestTy;
// }
// }
// }
TypeIterator Type::begin() { TypeIterator Type::begin() {
return TypeIterator { this, getStartIndex() }; return TypeIterator { this, getStartIndex() };
@ -333,14 +444,6 @@ namespace bolt {
TypeIndex Type::getStartIndex() { TypeIndex Type::getStartIndex() {
switch (Kind) { switch (Kind) {
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
if (Con->Args.empty()) {
return TypeIndex(TypeIndexKind::End);
}
return TypeIndex::forConArg(0);
}
case TypeKind::Arrow: case TypeKind::Arrow:
{ {
auto Arrow = static_cast<TArrow*>(this); auto Arrow = static_cast<TArrow*>(this);
@ -357,6 +460,8 @@ namespace bolt {
} }
return TypeIndex::forTupleElement(0); return TypeIndex::forTupleElement(0);
} }
case TypeKind::Field:
return TypeIndex::forFieldType();
default: default:
return TypeIndex(TypeIndexKind::End); return TypeIndex(TypeIndexKind::End);
} }
@ -366,4 +471,25 @@ namespace bolt {
return TypeIndex(TypeIndexKind::End); return TypeIndex(TypeIndexKind::End);
} }
// Returns the representative of this type variable: the first type reached
// by following Parent links that is either not a type variable, or a type
// variable whose Parent is itself (presumably a variable that has not been
// given a solution yet -- see TVar::set).
//
// While walking the chain, each visited variable is re-pointed at its
// parent's parent, so later lookups have fewer links to follow.
//
// This definition is intentionally not marked `inline`. An inline function
// must be defined in every translation unit that uses it; defining it inline
// in a source file risks an undefined reference at link time as soon as
// another translation unit calls find() and the optimizer has inlined every
// use here.
Type* TVar::find() {
  TVar* Curr = this;
  for (;;) {
    auto Keep = Curr->Parent;
    if (Keep->getKind() != TypeKind::Var || Keep == Curr) {
      return Keep;
    }
    auto TV = static_cast<TVar*>(Keep);
    // Skip one link: point the current variable at its grandparent.
    Curr->Parent = TV->Parent;
    Curr = TV;
  }
}
// Records Ty as the solution of this type variable by making it the parent
// of the variable that find() currently returns.
void TVar::set(Type* Ty) {
  Type* const Representative = find();
  // A variable can only be solved once: if find() already yields something
  // other than a type variable, a solution was set before.
  ZEN_ASSERT(Representative->getKind() == TypeKind::Var);
  auto* RootVar = static_cast<TVar*>(Representative);
  RootVar->Parent = Ty;
}
} }