Major update to code base

- Add partial support for extensible records
 - Rewrite unifier in Checker.cc
 - Make use of union/find instead of a HashMap for type variables
 - Enhance diagnostic messages
 - Add a variant type
 - Add application types (TApp)
 - Some smaller bugfixes
This commit is contained in:
Sam Vervaeck 2023-05-29 20:37:23 +02:00
parent dfaa91c9b6
commit 6bd8ecff39
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
11 changed files with 1748 additions and 612 deletions

View file

@ -1497,6 +1497,9 @@ namespace bolt {
Fields(Fields),
RBrace(RBrace) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class Statement : public Node {
@ -1800,6 +1803,7 @@ namespace bolt {
class PubKeyword* PubKeyword;
class StructKeyword* StructKeyword;
IdentifierAlt* Name;
std::vector<VarTypeExpression*> Vars;
class BlockStart* BlockStart;
std::vector<RecordDeclarationField*> Fields;
@ -1807,12 +1811,14 @@ namespace bolt {
class PubKeyword* PubKeyword,
class StructKeyword* StructKeyword,
IdentifierAlt* Name,
std::vector<VarTypeExpression*> Vars,
class BlockStart* BlockStart,
std::vector<RecordDeclarationField*> Fields
): Node(NodeKind::RecordDeclaration),
PubKeyword(PubKeyword),
StructKeyword(StructKeyword),
Name(Name),
Vars(Vars),
BlockStart(BlockStart),
Fields(Fields) {}

View file

@ -166,6 +166,9 @@ namespace bolt {
class Checker {
friend class Unifier;
friend class UnificationFrame;
const LanguageConfig& Config;
DiagnosticEngine& DE;
@ -178,14 +181,10 @@ namespace bolt {
Graph<Node*> RefGraph;
std::unordered_map<Node*, InferContext*> CallGraph;
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
std::vector<InferContext*> Contexts;
TVSub Solution;
/**
* The queue that is used during solving to store any unsolved constraints.
*/
@ -208,15 +207,14 @@ namespace bolt {
Type* inferTypeExpression(TypeExpression* TE);
Type* inferLiteral(Literal* Lit);
void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs);
void inferBindings(Pattern* Pattern, Type* T);
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void infer(Node* node);
void inferLetDeclaration(LetDeclaration* N);
Constraint* convertToConstraint(ConstraintExpression* C);
TCon* createPrimConType();
TCon* createConType(ByteString Name);
TVar* createTypeVar();
TVarRigid* createRigidVar(ByteString Name);
InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet);
@ -239,8 +237,6 @@ namespace bolt {
*/
Type* lookupMono(ByteString Name);
InferContext* lookupCall(Node* Source, SymbolPath Path);
/**
* Get the return type for the current context. If none could be found, the program will abort.
*/
@ -252,10 +248,6 @@ namespace bolt {
void propagateClasses(TypeclassContext& Classes, Type* Ty);
void propagateClassTycon(TypeclassId& Class, TCon* Ty);
Type* simplify(Type* Ty);
Type* find(Type* Ty);
/**
* Assign a type to a unification variable.
*
@ -268,18 +260,21 @@ namespace bolt {
*/
void join(TVar* A, Type* B);
// Unification parameters
Type* OrigLeft;
Type* OrigRight;
TypePath LeftPath;
TypePath RightPath;
ByteString CurrentFieldName;
Node* Source;
bool unify(Type* A, Type* B);
void unifyError();
void solveCEqual(CEqual* C);
void solve(Constraint* Constraint, TVSub& Solution);
void solve(Constraint* Constraint);
void populate(SourceFile* SF);
@ -293,17 +288,22 @@ namespace bolt {
Checker(const LanguageConfig& Config, DiagnosticEngine& DE);
/**
* \internal
*/
Type* simplifyType(Type* Ty);
void check(SourceFile* SF);
inline Type* getBoolType() {
inline Type* getBoolType() const {
return BoolType;
}
inline Type* getStringType() {
inline Type* getStringType() const {
return StringType;
}
inline Type* getIntType() {
inline Type* getIntType() const {
return IntType;
}

View file

@ -6,6 +6,8 @@
#include <iostream>
#include "bolt/ByteString.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
namespace bolt {
@ -60,6 +62,98 @@ namespace bolt {
Magenta,
};
enum StyleFlags : unsigned {
StyleFlags_None = 0,
StyleFlags_Bold = 1 << 0,
StyleFlags_Underline = 1 << 1,
StyleFlags_Italic = 1 << 2,
};
class Style {
unsigned Flags = StyleFlags_None;
Color FgColor = Color::None;
Color BgColor = Color::None;
public:
Color getForegroundColor() const noexcept {
return FgColor;
}
Color getBackgroundColor() const noexcept {
return BgColor;
}
void setForegroundColor(Color NewColor) noexcept {
FgColor = NewColor;
}
void setBackgroundColor(Color NewColor) noexcept {
BgColor = NewColor;
}
bool hasForegroundColor() const noexcept {
return FgColor != Color::None;
}
bool hasBackgroundColor() const noexcept {
return BgColor != Color::None;
}
void clearForegroundColor() noexcept {
FgColor = Color::None;
}
void clearBackgroundColor() noexcept {
BgColor = Color::None;
}
bool isUnderline() const noexcept {
return Flags & StyleFlags_Underline;
}
bool isItalic() const noexcept {
return Flags & StyleFlags_Italic;
}
bool isBold() const noexcept {
return Flags & StyleFlags_Bold;
}
void setUnderline(bool Enable) noexcept {
if (Enable) {
Flags |= StyleFlags_Underline;
} else {
Flags &= ~StyleFlags_Underline;
}
}
void setItalic(bool Enable) noexcept {
if (Enable) {
Flags |= StyleFlags_Italic;
} else {
Flags &= ~StyleFlags_Italic;
}
}
void setBold(bool Enable) noexcept {
if (Enable) {
Flags |= StyleFlags_Bold;
} else {
Flags &= ~StyleFlags_Bold;
}
}
void reset() noexcept {
FgColor = Color::None;
BgColor = Color::None;
Flags = 0;
}
};
/**
* Prints any diagnostic message that was added to it to the console.
*/
@ -67,8 +161,12 @@ namespace bolt {
std::ostream& Out;
Style ActiveStyle;
void setForegroundColor(Color C);
void setBackgroundColor(Color C);
void applyStyles();
void setBold(bool Enable);
void setItalic(bool Enable);
void setUnderline(bool Enable);
@ -99,6 +197,7 @@ namespace bolt {
void writePrefix(const Diagnostic& D);
void writeBinding(const ByteString& Name);
void writeType(std::size_t I);
void writeType(const Type* Ty, const TypePath& Underline);
void writeType(const Type* Ty);
void writeLoc(const TextFile& File, const TextLoc& Loc);
void writeTypeclassName(const ByteString& Name);

View file

@ -1,6 +1,7 @@
#pragma once
#include <cwchar>
#include <vector>
#include <stdexcept>
#include <memory>
@ -23,6 +24,7 @@ namespace bolt {
ClassNotFound,
TupleIndexOutOfRange,
InvalidTypeToTypeclass,
FieldNotFound,
};
class Diagnostic : std::runtime_error {
@ -88,14 +90,14 @@ namespace bolt {
class UnificationErrorDiagnostic : public Diagnostic {
public:
Type* Left;
Type* Right;
Type* OrigLeft;
Type* OrigRight;
TypePath LeftPath;
TypePath RightPath;
Node* Source;
inline UnificationErrorDiagnostic(Type* Left, Type* Right, TypePath LeftPath, TypePath RightPath, Node* Source):
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {}
inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source):
Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {}
inline Node* getNode() const override {
return Source;
@ -171,4 +173,17 @@ namespace bolt {
};
class FieldNotFoundDiagnostic : public Diagnostic {
public:
ByteString Name;
Type* Ty;
TypePath Path;
Node* Source;
inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source):
Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {}
};
}

View file

@ -82,6 +82,7 @@ namespace bolt {
MatchExpression* parseMatchExpression();
Expression* parseMemberExpression();
RecordExpression* parseRecordExpression();
Expression* parsePrimitiveExpression();
ConstraintExpression* parseConstraintExpression();

View file

@ -1,6 +1,8 @@
#pragma once
#include <functional>
#include <type_traits>
#include <vector>
#include <unordered_map>
#include <unordered_set>
@ -28,10 +30,15 @@ namespace bolt {
};
enum class TypeIndexKind {
AppOpType,
AppArgType,
ArrowParamType,
ArrowReturnType,
ConArg,
TupleElement,
FieldType,
FieldRestType,
TupleIndexType,
PresentType,
End,
};
@ -59,22 +66,42 @@ namespace bolt {
void advance(const Type* Ty);
static TypeIndex forArrowReturnType() {
return { TypeIndexKind::ArrowReturnType };
static TypeIndex forFieldType() {
return { TypeIndexKind::FieldType };
}
static TypeIndex forFieldRest() {
return { TypeIndexKind::FieldRestType };
}
static TypeIndex forArrowParamType(std::size_t I) {
return { TypeIndexKind::ArrowParamType, I };
}
static TypeIndex forConArg(std::size_t I) {
return { TypeIndexKind::ConArg, I };
static TypeIndex forArrowReturnType() {
return { TypeIndexKind::ArrowReturnType };
}
static TypeIndex forTupleElement(std::size_t I) {
return { TypeIndexKind::TupleElement, I };
}
static TypeIndex forAppOpType() {
return { TypeIndexKind::AppOpType };
}
static TypeIndex forAppArgType() {
return { TypeIndexKind::AppArgType };
}
static TypeIndex forTupleIndexType() {
return { TypeIndexKind::TupleIndexType };
}
static TypeIndex forPresentType() {
return { TypeIndexKind::PresentType };
}
};
class TypeIterator {
@ -116,9 +143,14 @@ namespace bolt {
enum class TypeKind : unsigned char {
Var,
Con,
App,
Arrow,
Tuple,
TupleIndex,
Field,
Nil,
Absent,
Present,
};
class Type {
@ -146,8 +178,18 @@ namespace bolt {
return Out;
}
/**
* Rewrites the entire substructure of a type to another one.
*
* \param Recursive If true, a succesfull local rewritten type will be again
* rewriten until it encounters some terminals.
*/
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = false);
Type* substitute(const TVSub& Sub);
Type* solve();
TypeIterator begin();
TypeIterator end();
@ -176,11 +218,10 @@ namespace bolt {
public:
const size_t Id;
std::vector<Type*> Args;
ByteString DisplayName;
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
inline TCon(const size_t Id, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), DisplayName(DisplayName) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con;
@ -188,12 +229,30 @@ namespace bolt {
};
class TApp : public Type {
public:
Type* Op;
Type* Arg;
inline TApp(Type* Op, Type* Arg):
Type(TypeKind::App), Op(Op), Arg(Arg) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::App;
}
};
enum class VarKind {
Rigid,
Unification,
};
class TVar : public Type {
Type* Parent = this;
public:
const size_t Id;
@ -208,6 +267,10 @@ namespace bolt {
return VK;
}
Type* find();
void set(Type* Ty);
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var;
}
@ -272,6 +335,215 @@ namespace bolt {
};
class TNil : public Type {
public:
inline TNil():
Type(TypeKind::Nil) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Nil;
}
};
class TField : public Type {
public:
ByteString Name;
Type* Ty;
Type* RestTy;
inline TField(
ByteString Name,
Type* Ty,
Type* RestTy
): Type(TypeKind::Field),
Name(Name),
Ty(Ty),
RestTy(RestTy) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Field;
}
};
class TAbsent : public Type {
public:
inline TAbsent():
Type(TypeKind::Absent) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Absent;
}
};
class TPresent : public Type {
public:
Type* Ty;
inline TPresent(Type* Ty):
Type(TypeKind::Present), Ty(Ty) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Present;
}
};
template<bool IsConst>
class TypeVisitorBase {
protected:
template<typename T>
using C = std::conditional<IsConst, const T, T>::type;
virtual void enterType(C<Type>* Ty) {}
virtual void exitType(C<Type>* Ty) {}
virtual void visitVarType(C<TVar>* Ty) {
visitEachChild(Ty);
}
virtual void visitAppType(C<TApp>* Ty) {
visitEachChild(Ty);
}
virtual void visitPresentType(C<TPresent>* Ty) {
visitEachChild(Ty);
}
virtual void visitConType(C<TCon>* Ty) {
visitEachChild(Ty);
}
virtual void visitArrowType(C<TArrow>* Ty) {
visitEachChild(Ty);
}
virtual void visitTupleType(C<TTuple>* Ty) {
visitEachChild(Ty);
}
virtual void visitTupleIndexType(C<TTupleIndex>* Ty) {
visitEachChild(Ty);
}
virtual void visitAbsentType(C<TAbsent>* Ty) {
visitEachChild(Ty);
}
virtual void visitFieldType(C<TField>* Ty) {
visitEachChild(Ty);
}
virtual void visitNilType(C<TNil>* Ty) {
visitEachChild(Ty);
}
public:
void visitEachChild(C<Type>* Ty) {
switch (Ty->getKind()) {
case TypeKind::Var:
case TypeKind::Absent:
case TypeKind::Nil:
case TypeKind::Con:
break;
case TypeKind::Arrow:
{
auto Arrow = static_cast<C<TArrow>*>(Ty);
for (auto I = 0; I < Arrow->ParamTypes.size(); ++I) {
visit(Arrow->ParamTypes[I]);
}
visit(Arrow->ReturnType);
break;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<C<TTuple>*>(Ty);
for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) {
visit(Tuple->ElementTypes[I]);
}
break;
}
case TypeKind::App:
{
auto App = static_cast<C<TApp>*>(Ty);
visit(App->Op);
visit(App->Arg);
break;
}
case TypeKind::Field:
{
auto Field = static_cast<C<TField>*>(Ty);
visit(Field->Ty);
visit(Field->RestTy);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<C<TPresent>*>(Ty);
visit(Present->Ty);
break;
}
case TypeKind::TupleIndex:
{
auto Index = static_cast<C<TTupleIndex>*>(Ty);
visit(Index->Ty);
break;
}
}
}
void visit(C<Type>* Ty) {
enterType(Ty);
switch (Ty->getKind()) {
case TypeKind::Present:
visitPresentType(static_cast<C<TPresent>*>(Ty));
break;
case TypeKind::Absent:
visitAbsentType(static_cast<C<TAbsent>*>(Ty));
break;
case TypeKind::Nil:
visitNilType(static_cast<C<TNil>*>(Ty));
break;
case TypeKind::Field:
visitFieldType(static_cast<C<TField>*>(Ty));
break;
case TypeKind::Con:
visitConType(static_cast<C<TCon>*>(Ty));
break;
case TypeKind::Arrow:
visitArrowType(static_cast<C<TArrow>*>(Ty));
break;
case TypeKind::Var:
visitVarType(static_cast<C<TVar>*>(Ty));
break;
case TypeKind::Tuple:
visitTupleType(static_cast<C<TTuple>*>(Ty));
break;
case TypeKind::App:
visitAppType(static_cast<C<TApp>*>(Ty));
break;
case TypeKind::TupleIndex:
visitTupleIndexType(static_cast<C<TTupleIndex>*>(Ty));
break;
}
exitType(Ty);
}
virtual ~TypeVisitorBase() {}
};
using TypeVisitor = TypeVisitorBase<false>;
using ConstTypeVisitor = TypeVisitorBase<true>;
// template<typename T>
// struct DerefHash {
// std::size_t operator()(const T& Value) const noexcept {

View file

@ -417,6 +417,22 @@ namespace bolt {
return BlockStart;
}
Token* RecordExpressionField::getFirstToken() const {
return Name;
}
Token* RecordExpressionField::getLastToken() const {
return E->getLastToken();
}
Token* RecordExpression::getFirstToken() const {
return LBrace;
}
Token* RecordExpression::getLastToken() const {
return RBrace;
}
Token* MemberExpression::getFirstToken() const {
return E->getFirstToken();
}

File diff suppressed because it is too large Load diff

View file

@ -13,6 +13,7 @@
#define ANSI_RESET "\u001b[0m"
#define ANSI_BOLD "\u001b[1m"
#define ANSI_ITALIC "\u001b[3m"
#define ANSI_UNDERLINE "\u001b[4m"
#define ANSI_REVERSED "\u001b[7m"
@ -107,6 +108,16 @@ namespace bolt {
return "'return'";
case NodeKind::TypeKeyword:
return "'type'";
case NodeKind::LetDeclaration:
return "a let-declaration";
case NodeKind::CallExpression:
return "a call-expression";
case NodeKind::InfixExpression:
return "an infix-expression";
case NodeKind::ReferenceExpression:
return "a function or variable reference";
case NodeKind::MatchExpression:
return "a match-expression";
default:
ZEN_UNREACHABLE
}
@ -151,16 +162,12 @@ namespace bolt {
case TypeKind::Con:
{
auto Y = static_cast<const TCon*>(Ty);
std::ostringstream Out;
if (!Y->DisplayName.empty()) {
Out << Y->DisplayName;
} else {
Out << "C" << Y->Id;
}
for (auto Arg: Y->Args) {
Out << " " << describe(Arg);
}
return Out.str();
return Y->DisplayName;
}
case TypeKind::App:
{
auto Y = static_cast<const TApp*>(Ty);
return describe(Y->Op) + " " + describe(Y->Arg);
}
case TypeKind::Tuple:
{
@ -182,6 +189,94 @@ namespace bolt {
auto Y = static_cast<const TTupleIndex*>(Ty);
return describe(Y->Ty) + "." + std::to_string(Y->I);
}
case TypeKind::Nil:
return "{}";
case TypeKind::Absent:
return "Abs";
case TypeKind::Present:
{
auto Y = static_cast<const TPresent*>(Ty);
return describe(Y->Ty);
}
case TypeKind::Field:
{
auto Y = static_cast<const TField*>(Ty);
std::ostringstream out;
out << "{ " << Y->Name << ": " << describe(Y->Ty);
Ty = Y->RestTy;
while (Ty->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty);
out << "; " + Y->Name + ": " + describe(Y->Ty);
Ty = Y->RestTy;
}
if (Ty->getKind() != TypeKind::Nil) {
out << "; " + describe(Ty);
}
out << " }";
return out.str();
}
}
}
void writeForegroundANSI(Color C, std::ostream& Out) {
switch (C) {
case Color::None:
break;
case Color::Black:
Out << ANSI_FG_BLACK;
break;
case Color::White:
Out << ANSI_FG_WHITE;
break;
case Color::Red:
Out << ANSI_FG_RED;
break;
case Color::Yellow:
Out << ANSI_FG_YELLOW;
break;
case Color::Green:
Out << ANSI_FG_GREEN;
break;
case Color::Blue:
Out << ANSI_FG_BLUE;
break;
case Color::Cyan:
Out << ANSI_FG_CYAN;
break;
case Color::Magenta:
Out << ANSI_FG_MAGENTA;
break;
}
}
void writeBackgroundANSI(Color C, std::ostream& Out) {
switch (C) {
case Color::None:
break;
case Color::Black:
Out << ANSI_BG_BLACK;
break;
case Color::White:
Out << ANSI_BG_WHITE;
break;
case Color::Red:
Out << ANSI_BG_RED;
break;
case Color::Yellow:
Out << ANSI_BG_YELLOW;
break;
case Color::Green:
Out << ANSI_BG_GREEN;
break;
case Color::Blue:
Out << ANSI_BG_BLUE;
break;
case Color::Cyan:
Out << ANSI_BG_CYAN;
break;
case Color::Magenta:
Out << ANSI_BG_MAGENTA;
break;
}
}
@ -195,91 +290,84 @@ namespace bolt {
Out(Out) {}
void ConsoleDiagnostics::setForegroundColor(Color C) {
if (EnableColors) {
switch (C) {
case Color::None:
break;
case Color::Black:
Out << ANSI_FG_BLACK;
break;
case Color::White:
Out << ANSI_FG_WHITE;
break;
case Color::Red:
Out << ANSI_FG_RED;
break;
case Color::Yellow:
Out << ANSI_FG_YELLOW;
break;
case Color::Green:
Out << ANSI_FG_GREEN;
break;
case Color::Blue:
Out << ANSI_FG_BLUE;
break;
case Color::Cyan:
Out << ANSI_FG_CYAN;
break;
case Color::Magenta:
Out << ANSI_FG_MAGENTA;
break;
}
ActiveStyle.setForegroundColor(C);
if (!EnableColors) {
return;
}
writeForegroundANSI(C, Out);
}
void ConsoleDiagnostics::setBackgroundColor(Color C) {
if (EnableColors) {
switch (C) {
case Color::None:
break;
case Color::Black:
Out << ANSI_BG_BLACK;
break;
case Color::White:
Out << ANSI_BG_WHITE;
break;
case Color::Red:
Out << ANSI_BG_RED;
break;
case Color::Yellow:
Out << ANSI_BG_YELLOW;
break;
case Color::Green:
Out << ANSI_BG_GREEN;
break;
case Color::Blue:
Out << ANSI_BG_BLUE;
break;
case Color::Cyan:
Out << ANSI_BG_CYAN;
break;
case Color::Magenta:
Out << ANSI_BG_MAGENTA;
break;
}
ActiveStyle.setBackgroundColor(C);
if (!EnableColors) {
return;
}
if (C == Color::None) {
Out << ANSI_RESET;
applyStyles();
}
writeBackgroundANSI(C, Out);
}
void ConsoleDiagnostics::applyStyles() {
if (ActiveStyle.isBold()) {
Out << ANSI_BOLD;
}
if (ActiveStyle.isUnderline()) {
Out << ANSI_UNDERLINE;
}
if (ActiveStyle.isItalic()) {
Out << ANSI_ITALIC;
}
if (ActiveStyle.hasBackgroundColor()) {
setBackgroundColor(ActiveStyle.getBackgroundColor());
}
if (ActiveStyle.hasForegroundColor()) {
setForegroundColor(ActiveStyle.getForegroundColor());
}
}
void ConsoleDiagnostics::setBold(bool Enable) {
ActiveStyle.setBold(Enable);
if (!EnableColors) {
return;
}
if (Enable) {
Out << ANSI_BOLD;
} else {
Out << ANSI_RESET;
applyStyles();
}
}
void ConsoleDiagnostics::setItalic(bool Enable) {
ActiveStyle.setItalic(Enable);
if (!EnableColors) {
return;
}
if (Enable) {
// TODO
Out << ANSI_ITALIC;
} else {
Out << ANSI_RESET;
applyStyles();
}
}
void ConsoleDiagnostics::setUnderline(bool Enable) {
ActiveStyle.setItalic(Enable);
if (!EnableColors) {
return;
}
if (Enable) {
Out << ANSI_UNDERLINE;
} else {
Out << ANSI_RESET;
applyStyles();
}
}
void ConsoleDiagnostics::resetStyles() {
ActiveStyle.reset();
if (EnableColors) {
Out << ANSI_RESET;
}
@ -391,8 +479,159 @@ namespace bolt {
}
void ConsoleDiagnostics::writeType(const Type* Ty) {
TypePath Path;
writeType(Ty, Path);
}
void ConsoleDiagnostics::writeType(const Type* Ty, const TypePath& Underline) {
setForegroundColor(Color::Green);
write(describe(Ty));
class TypePrinter : public ConstTypeVisitor {
TypePath Path;
ConsoleDiagnostics& W;
const TypePath& Underline;
public:
TypePrinter(ConsoleDiagnostics& W, const TypePath& Underline):
W(W), Underline(Underline) {}
bool shouldUnderline() const {
return !Underline.empty() && Path == Underline;
}
void enterType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(true);
}
}
void exitType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(false);
}
}
void visitAppType(const TApp *Ty) override {
auto Y = static_cast<const TApp*>(Ty);
Path.push_back(TypeIndex::forAppOpType());
visit(Y->Op);
Path.pop_back();
W.write(" ");
Path.push_back(TypeIndex::forAppArgType());
visit(Y->Arg);
Path.pop_back();
}
void visitVarType(const TVar* Ty) override {
if (Ty->getVarKind() == VarKind::Rigid) {
W.write(static_cast<const TVarRigid*>(Ty)->Name);
return;
}
W.write("a");
W.write(Ty->Id);
}
void visitConType(const TCon *Ty) override {
W.write(Ty->DisplayName);
}
void visitArrowType(const TArrow* Ty) override {
W.write("(");
bool First = true;
std::size_t I = 0;
for (auto PT: Ty->ParamTypes) {
if (First) First = false;
else W.write(", ");
Path.push_back(TypeIndex::forArrowParamType(I++));
visit(PT);
Path.pop_back();
}
W.write(") -> ");
Path.push_back(TypeIndex::forArrowReturnType());
visit(Ty->ReturnType);
Path.pop_back();
}
void visitTupleType(const TTuple *Ty) override {
W.write("(");
if (Ty->ElementTypes.size()) {
auto Iter = Ty->ElementTypes.begin();
Path.push_back(TypeIndex::forTupleElement(0));
visit(*Iter++);
Path.pop_back();
std::size_t I = 1;
while (Iter != Ty->ElementTypes.end()) {
W.write(", ");
Path.push_back(TypeIndex::forTupleElement(I++));
visit(*Iter++);
Path.pop_back();
}
}
W.write(")");
}
void visitTupleIndexType(const TTupleIndex *Ty) override {
Path.push_back(TypeIndex::forTupleIndexType());
visit(Ty->Ty);
Path.pop_back();
W.write(".");
W.write(Ty->I);
}
void visitNilType(const TNil *Ty) override {
W.write("{}");
}
void visitAbsentType(const TAbsent *Ty) override {
W.write("Abs");
}
void visitPresentType(const TPresent *Ty) override {
Path.push_back(TypeIndex::forPresentType());
visit(Ty->Ty);
Path.pop_back();
}
void visitFieldType(const TField* Ty) override {
W.write("{ ");
W.write(Ty->Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Ty->Ty);
Path.pop_back();
auto Ty2 = Ty->RestTy;
Path.push_back(TypeIndex::forFieldRest());
std::size_t I = 1;
while (Ty2->getKind() == TypeKind::Field) {
auto Y = static_cast<const TField*>(Ty2);
W.write("; ");
W.write(Y->Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Y->Ty);
Path.pop_back();
Ty2 = Y->RestTy;
Path.push_back(TypeIndex::forFieldRest());
++I;
}
if (Ty2->getKind() != TypeKind::Nil) {
W.write("; ");
visit(Ty);
}
W.write(" }");
for (auto K = 0; K < I; K++) {
Path.pop_back();
}
}
};
TypePrinter P { *this, Underline };
P.visit(Ty);
resetStyles();
}
@ -533,40 +772,51 @@ namespace bolt {
case DiagnosticKind::UnificationError:
{
auto E = static_cast<const UnificationErrorDiagnostic&>(D);
auto Left = E.OrigLeft->resolve(E.LeftPath);
auto Right = E.OrigRight->resolve(E.RightPath);
writePrefix(E);
auto Left = E.Left->resolve(E.LeftPath);
auto Right = E.Right->resolve(E.RightPath);
write("the types ");
writeType(Left);
write(" and ");
writeType(Right);
write(" failed to match\n\n");
if (E.Source) {
writeNode(E.Source);
Out << "\n";
}
if (!E.LeftPath.empty()) {
setForegroundColor(Color::Yellow);
setBold(true);
write(" info: ");
resetStyles();
write("the type ");
writeType(Left);
write(" occurs in the full type ");
writeType(E.Left);
write("\n\n");
}
if (!E.RightPath.empty()) {
setForegroundColor(Color::Yellow);
setBold(true);
write(" info: ");
resetStyles();
write("the type ");
writeType(Right);
write(" occurs in the full type ");
writeType(E.Right);
write("\n\n");
}
setForegroundColor(Color::Yellow);
setBold(true);
write(" info: ");
resetStyles();
write("due to an equality constraint on ");
write(describe(E.Source->getKind()));
write(":\n\n");
write(" - left type ");
writeType(E.OrigLeft, E.LeftPath);
write("\n");
write(" - right type ");
writeType(E.OrigRight, E.RightPath);
write("\n\n");
writeNode(E.Source);
write("\n");
// if (E.Left != E.OrigLeft) {
// setForegroundColor(Color::Yellow);
// setBold(true);
// write(" info: ");
// resetStyles();
// write("the type ");
// writeType(E.Left);
// write(" occurs in the full type ");
// writeType(E.OrigLeft);
// write("\n\n");
// }
// if (E.Right != E.OrigRight) {
// setForegroundColor(Color::Yellow);
// setBold(true);
// write(" info: ");
// resetStyles();
// write("the type ");
// writeType(E.Right);
// write(" occurs in the full type ");
// writeType(E.OrigRight);
// write("\n\n");
// }
break;
}
@ -634,6 +884,18 @@ namespace bolt {
break;
}
case DiagnosticKind::FieldNotFound:
{
auto E = static_cast<const FieldNotFoundDiagnostic&>(D);
writePrefix(E);
write("the field '");
write(E.Name);
write("' was required in one type but not found in another\n\n");
writeNode(E.Source);
write("\n");
break;
}
}
}

View file

@ -473,6 +473,75 @@ after_tuple_element:
return new MatchExpression(static_cast<MatchKeyword*>(T0), Value, BlockStart, Cases);
}
RecordExpression* Parser::parseRecordExpression() {
auto LBrace = expectToken<class LBrace>();
if (!LBrace) {
return nullptr;
}
RBrace* RBrace;
auto T1 = Tokens.peek();
std::vector<std::tuple<RecordExpressionField*, Comma*>> Fields;
if (T1->getKind() == NodeKind::RBrace) {
Tokens.get();
RBrace = static_cast<class RBrace*>(T1);
} else {
for (;;) {
auto Name = expectToken<Identifier>();
if (!Name) {
LBrace->unref();
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
return nullptr;
}
auto Equals = expectToken<class Equals>();
if (!Equals) {
LBrace->unref();
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
return nullptr;
}
auto E = parseExpression();
if (!E) {
LBrace->unref();
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
Equals->unref();
return nullptr;
}
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::Comma) {
Tokens.get();
Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, static_cast<Comma*>(T2)));
} else if (T2->getKind() == NodeKind::RBrace) {
Tokens.get();
RBrace = static_cast<class RBrace*>(T2);
Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, nullptr));
break;
} else {
DE.add<UnexpectedTokenDiagnostic>(File, T2, std::vector { NodeKind::Comma, NodeKind::RBrace });
LBrace->unref();
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
Equals->unref();
E->unref();
return nullptr;
}
}
}
return new RecordExpression { LBrace, Fields, RBrace };
}
Expression* Parser::parsePrimitiveExpression() {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
@ -562,9 +631,11 @@ after_tuple_elements:
case NodeKind::StringLiteral:
Tokens.get();
return new ConstantExpression(static_cast<Literal*>(T0));
case NodeKind::LBrace:
return parseRecordExpression();
default:
// Tokens.get();
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::IntegerLiteral, NodeKind::StringLiteral });
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen, NodeKind::LBrace, NodeKind::IntegerLiteral, NodeKind::StringLiteral });
return nullptr;
}
}
@ -603,7 +674,12 @@ finish:
std::vector<Expression*> Args;
for (;;) {
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || T1->getKind() == NodeKind::Comma || ExprOperators.isInfix(T1)) {
if (T1->getKind() == NodeKind::LineFoldEnd
|| T1->getKind() == NodeKind::RParen
|| T1->getKind() == NodeKind::RBrace
|| T1->getKind() == NodeKind::BlockStart
|| T1->getKind() == NodeKind::Comma
|| ExprOperators.isInfix(T1)) {
break;
}
auto Arg = parsePrimitiveExpression();

View file

@ -28,7 +28,6 @@ namespace bolt {
return false;
}
switch (Kind) {
case TypeIndexKind::ConArg:
case TypeIndexKind::ArrowParamType:
case TypeIndexKind::TupleElement:
return I == Other.I;
@ -41,6 +40,9 @@ namespace bolt {
switch (Kind) {
case TypeIndexKind::End:
break;
case TypeIndexKind::AppOpType:
Kind = TypeIndexKind::AppArgType;
break;
case TypeIndexKind::ArrowParamType:
{
auto Arrow = llvm::cast<TArrow>(Ty);
@ -51,19 +53,16 @@ namespace bolt {
}
break;
}
case TypeIndexKind::FieldType:
Kind = TypeIndexKind::FieldRestType;
break;
case TypeIndexKind::FieldRestType:
case TypeIndexKind::TupleIndexType:
case TypeIndexKind::PresentType:
case TypeIndexKind::AppArgType:
case TypeIndexKind::ArrowReturnType:
Kind = TypeIndexKind::End;
break;
case TypeIndexKind::ConArg:
{
auto Con = llvm::cast<TCon>(Ty);
if (I+1 < Con->Args.size()) {
++I;
} else {
Kind = TypeIndexKind::End;
}
break;
}
case TypeIndexKind::TupleElement:
{
auto Tuple = llvm::cast<TTuple>(Ty);
@ -77,6 +76,95 @@ namespace bolt {
}
}
Type* Type::rewrite(std::function<Type*(Type*)> Fn, bool Recursive) {
auto Ty2 = Fn(this);
if (!Recursive && this != Ty2) {
return Ty2;
}
switch (Kind) {
case TypeKind::Var:
return Ty2;
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(Ty2);
bool Changed = false;
std::vector<Type*> NewParamTypes;
for (auto Ty: Arrow->ParamTypes) {
auto NewParamType = Ty->rewrite(Fn);
if (NewParamType != Ty) {
Changed = true;
}
NewParamTypes.push_back(NewParamType);
}
auto NewRetTy = Arrow->ReturnType->rewrite(Fn);
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamTypes, NewRetTy) : Ty2;
}
case TypeKind::Con:
return Ty2;
case TypeKind::App:
{
auto App = static_cast<TApp*>(Ty2);
auto NewOp = App->Op->rewrite(Fn);
auto NewArg = App->Arg->rewrite(Fn);
if (NewOp == App->Op && NewArg == App->Arg) {
return App;
}
return new TApp(NewOp, NewArg);
}
case TypeKind::TupleIndex:
{
auto Tuple = static_cast<TTupleIndex*>(Ty2);
auto NewTy = Tuple->Ty->rewrite(Fn);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(Ty2);
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = Ty->rewrite(Fn);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : Ty2;
}
case TypeKind::Nil:
return Ty2;
case TypeKind::Absent:
return Ty2;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(Ty2);
bool Changed = false;
auto NewTy = Field->Ty->rewrite(Fn);
if (NewTy != Field->Ty) {
Changed = true;
}
auto NewRestTy = Field->RestTy->rewrite(Fn);
if (NewRestTy != Field->RestTy) {
Changed = true;
}
return Changed ? new TField(Field->Name, NewTy, NewRestTy) : Ty2;
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(Ty2);
auto NewTy = Present->Ty->rewrite(Fn);
if (NewTy == Present->Ty) {
return Ty2;
}
return new TPresent(NewTy);
}
}
}
void Type::addTypeVars(TVSet& TVs) {
switch (Kind) {
case TypeKind::Var:
@ -92,11 +180,12 @@ namespace bolt {
break;
}
case TypeKind::Con:
break;
case TypeKind::App:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
Ty->addTypeVars(TVs);
}
auto App = static_cast<TApp*>(this);
App->Op->addTypeVars(TVs);
App->Arg->addTypeVars(TVs);
break;
}
case TypeKind::TupleIndex:
@ -113,6 +202,23 @@ namespace bolt {
}
break;
}
case TypeKind::Nil:
break;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
Field->Ty->addTypeVars(TVs);
Field->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
Present->Ty->addTypeVars(TVs);
break;
}
case TypeKind::Absent:
break;
}
}
@ -131,14 +237,11 @@ namespace bolt {
return Arrow->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
for (auto Ty: Con->Args) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
case TypeKind::App:
{
auto App = static_cast<TApp*>(this);
return App->Op->hasTypeVar(TV) && App->Arg->hasTypeVar(TV);
}
case TypeKind::TupleIndex:
{
@ -155,173 +258,181 @@ namespace bolt {
}
return false;
}
case TypeKind::Nil:
return false;
case TypeKind::Field:
{
auto Field = static_cast<TField*>(this);
return Field->Ty->hasTypeVar(TV) || Field->RestTy->hasTypeVar(TV);
}
case TypeKind::Present:
{
auto Present = static_cast<TPresent*>(this);
return Present->Ty->hasTypeVar(TV);
}
case TypeKind::Absent:
return false;
}
}
Type* Type::solve() {
return rewrite([](auto Ty) {
if (Ty->getKind() == TypeKind::Var) {
return static_cast<TVar*>(Ty)->find();
}
return Ty;
});
}
Type* Type::substitute(const TVSub &Sub) {
switch (Kind) {
case TypeKind::Var:
{
auto TV = static_cast<TVar*>(this);
return rewrite([&](auto Ty) {
if (llvm::isa<TVar>(Ty)) {
auto TV = static_cast<TVar*>(Ty);
auto Match = Sub.find(TV);
return Match != Sub.end() ? Match->second->substitute(Sub) : this;
return Match != Sub.end() ? Match->second->substitute(Sub) : Ty;
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
bool Changed = false;
std::vector<Type*> NewParamTypes;
for (auto Ty: Arrow->ParamTypes) {
auto NewParamType = Ty->substitute(Sub);
if (NewParamType != Ty) {
Changed = true;
}
NewParamTypes.push_back(NewParamType);
}
auto NewRetTy = Arrow->ReturnType->substitute(Sub) ;
if (NewRetTy != Arrow->ReturnType) {
Changed = true;
}
return Changed ? new TArrow(NewParamTypes, NewRetTy) : this;
}
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
bool Changed = false;
std::vector<Type*> NewArgs;
for (auto Arg: Con->Args) {
auto NewArg = Arg->substitute(Sub);
if (NewArg != Arg) {
Changed = true;
}
NewArgs.push_back(NewArg);
}
return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this;
}
case TypeKind::TupleIndex:
{
auto Tuple = static_cast<TTupleIndex*>(this);
auto NewTy = Tuple->Ty->substitute(Sub);
return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple;
}
case TypeKind::Tuple:
{
auto Tuple = static_cast<TTuple*>(this);
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple->ElementTypes) {
auto NewElementType = Ty->substitute(Sub);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new TTuple(NewElementTypes) : this;
}
}
return Ty;
});
}
Type* Type::resolve(const TypeIndex& Index) const noexcept {
switch (Index.Kind) {
case TypeIndexKind::ConArg:
return llvm::cast<TCon>(this)->Args[Index.I];
case TypeIndexKind::PresentType:
return llvm::cast<TPresent>(this)->Ty;
case TypeIndexKind::AppOpType:
return llvm::cast<TApp>(this)->Op;
case TypeIndexKind::AppArgType:
return llvm::cast<TApp>(this)->Arg;
case TypeIndexKind::TupleIndexType:
return llvm::cast<TTupleIndex>(this)->Ty;
case TypeIndexKind::TupleElement:
return llvm::cast<TTuple>(this)->ElementTypes[Index.I];
case TypeIndexKind::ArrowParamType:
return llvm::cast<TArrow>(this)->ParamTypes[Index.I];
case TypeIndexKind::ArrowReturnType:
return llvm::cast<TArrow>(this)->ReturnType;
case TypeIndexKind::FieldType:
return llvm::cast<TField>(this)->Ty;
case TypeIndexKind::FieldRestType:
return llvm::cast<TField>(this)->RestTy;
case TypeIndexKind::End:
ZEN_UNREACHABLE
}
ZEN_UNREACHABLE
}
bool Type::operator==(const Type& Other) const noexcept {
switch (Kind) {
case TypeKind::Var:
if (Other.Kind != TypeKind::Var) {
return false;
}
return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
case TypeKind::Tuple:
{
if (Other.Kind != TypeKind::Tuple) {
return false;
}
auto A = static_cast<const TTuple&>(*this);
auto B = static_cast<const TTuple&>(Other);
if (A.ElementTypes.size() != B.ElementTypes.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
case TypeKind::TupleIndex:
{
if (Other.Kind != TypeKind::TupleIndex) {
return false;
}
auto A = static_cast<const TTupleIndex&>(*this);
auto B = static_cast<const TTupleIndex&>(Other);
return A.I == B.I && *A.Ty == *B.Ty;
}
case TypeKind::Con:
{
if (Other.Kind != TypeKind::Con) {
return false;
}
auto A = static_cast<const TCon&>(*this);
auto B = static_cast<const TCon&>(Other);
if (A.Id != B.Id) {
return false;
}
if (A.Args.size() != B.Args.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.Args, B.Args)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
case TypeKind::Arrow:
{
// FIXME Do we really need to 'curry' this type?
if (Other.Kind != TypeKind::Arrow) {
return false;
}
auto A = static_cast<const TArrow&>(*this);
auto B = static_cast<const TArrow&>(Other);
/* ArrowCursor C1 { &A }; */
/* ArrowCursor C2 { &B }; */
/* for (;;) { */
/* auto T1 = C1.next(); */
/* auto T2 = C2.next(); */
/* if (T1 == nullptr && T2 == nullptr) { */
/* break; */
/* } */
/* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */
/* return false; */
/* } */
/* } */
if (A.ParamTypes.size() != B.ParamTypes.size()) {
return false;
}
for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) {
if (*T1 != *T2) {
return false;
}
}
return A.ReturnType != B.ReturnType;
}
}
}
// bool Type::operator==(const Type& Other) const noexcept {
// switch (Kind) {
// case TypeKind::Var:
// if (Other.Kind != TypeKind::Var) {
// return false;
// }
// return static_cast<const TVar*>(this)->Id == static_cast<const TVar&>(Other).Id;
// case TypeKind::Tuple:
// {
// if (Other.Kind != TypeKind::Tuple) {
// return false;
// }
// auto A = static_cast<const TTuple&>(*this);
// auto B = static_cast<const TTuple&>(Other);
// if (A.ElementTypes.size() != B.ElementTypes.size()) {
// return false;
// }
// for (auto [T1, T2]: zen::zip(A.ElementTypes, B.ElementTypes)) {
// if (*T1 != *T2) {
// return false;
// }
// }
// return true;
// }
// case TypeKind::TupleIndex:
// {
// if (Other.Kind != TypeKind::TupleIndex) {
// return false;
// }
// auto A = static_cast<const TTupleIndex&>(*this);
// auto B = static_cast<const TTupleIndex&>(Other);
// return A.I == B.I && *A.Ty == *B.Ty;
// }
// case TypeKind::Con:
// {
// if (Other.Kind != TypeKind::Con) {
// return false;
// }
// auto A = static_cast<const TCon&>(*this);
// auto B = static_cast<const TCon&>(Other);
// if (A.Id != B.Id) {
// return false;
// }
// if (A.Args.size() != B.Args.size()) {
// return false;
// }
// for (auto [T1, T2]: zen::zip(A.Args, B.Args)) {
// if (*T1 != *T2) {
// return false;
// }
// }
// return true;
// }
// case TypeKind::Arrow:
// {
// if (Other.Kind != TypeKind::Arrow) {
// return false;
// }
// auto A = static_cast<const TArrow&>(*this);
// auto B = static_cast<const TArrow&>(Other);
// /* ArrowCursor C1 { &A }; */
// /* ArrowCursor C2 { &B }; */
// /* for (;;) { */
// /* auto T1 = C1.next(); */
// /* auto T2 = C2.next(); */
// /* if (T1 == nullptr && T2 == nullptr) { */
// /* break; */
// /* } */
// /* if (T1 == nullptr || T2 == nullptr || *T1 != *T2) { */
// /* return false; */
// /* } */
// /* } */
// if (A.ParamTypes.size() != B.ParamTypes.size()) {
// return false;
// }
// for (auto [T1, T2]: zen::zip(A.ParamTypes, B.ParamTypes)) {
// if (*T1 != *T2) {
// return false;
// }
// }
// return A.ReturnType != B.ReturnType;
// }
// case TypeKind::Absent:
// if (Other.Kind != TypeKind::Absent) {
// return false;
// }
// return true;
// case TypeKind::Nil:
// if (Other.Kind != TypeKind::Nil) {
// return false;
// }
// return true;
// case TypeKind::Present:
// {
// if (Other.Kind != TypeKind::Present) {
// return false;
// }
// auto A = static_cast<const TPresent&>(*this);
// auto B = static_cast<const TPresent&>(Other);
// return *A.Ty == *B.Ty;
// }
// case TypeKind::Field:
// {
// if (Other.Kind != TypeKind::Field) {
// return false;
// }
// auto A = static_cast<const TField&>(*this);
// auto B = static_cast<const TField&>(Other);
// return *A.Ty == *B.Ty && *A.RestTy == *B.RestTy;
// }
// }
// }
TypeIterator Type::begin() {
return TypeIterator { this, getStartIndex() };
@ -333,14 +444,6 @@ namespace bolt {
TypeIndex Type::getStartIndex() {
switch (Kind) {
case TypeKind::Con:
{
auto Con = static_cast<TCon*>(this);
if (Con->Args.empty()) {
return TypeIndex(TypeIndexKind::End);
}
return TypeIndex::forConArg(0);
}
case TypeKind::Arrow:
{
auto Arrow = static_cast<TArrow*>(this);
@ -357,6 +460,8 @@ namespace bolt {
}
return TypeIndex::forTupleElement(0);
}
case TypeKind::Field:
return TypeIndex::forFieldType();
default:
return TypeIndex(TypeIndexKind::End);
}
@ -366,4 +471,25 @@ namespace bolt {
return TypeIndex(TypeIndexKind::End);
}
inline Type* TVar::find() {
TVar* Curr = this;
for (;;) {
auto Keep = Curr->Parent;
if (Keep->getKind() != TypeKind::Var || Keep == Curr) {
return Keep;
}
auto TV = static_cast<TVar*>(Keep);
Curr->Parent = TV->Parent;
Curr = TV;
}
}
void TVar::set(Type* Ty) {
auto Root = find();
// It is not possible to set a solution twice.
ZEN_ASSERT(Root->getKind() == TypeKind::Var);
static_cast<TVar*>(Root)->Parent = Ty;
}
}