Switch to bidirectional type-checker and many more improvements

This commit is contained in:
Sam Vervaeck 2024-06-21 00:18:44 +02:00
parent c907885420
commit 5ba2aafc68
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
24 changed files with 1171 additions and 3690 deletions

View file

@ -1,11 +1,12 @@
cmake_minimum_required(VERSION 3.10)
cmake_minimum_required(VERSION 3.20)
project(Bolt C CXX)
set(CMAKE_CXX_STANDARD 20)
add_subdirectory(deps/zen EXCLUDE_FROM_ALL)
add_subdirectory(deps/llvm-project/llvm EXCLUDE_FROM_ALL)
set(ICU_DIR "${CMAKE_CURRENT_SOURCE_DIR}/build/icu/install")
set(ICU_CFLAGS "-DUNISTR_FROM_CHAR_EXPLICIT=explicit -DUNISTR_FROM_STRING_EXPLICIT=explicit -DU_NO_DEFAULT_INCLUDE_UTF_HEADERS=1 -DU_HIDE_OBSOLETE_UTF_OLD_H=1")
@ -17,7 +18,7 @@ if (CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo" OR CMAKE_BUILD_TYPE STREQUAL "Deb
set(BOLT_DEBUG ON)
endif()
find_package(LLVM 18.1.0 REQUIRED)
#find_package(LLVM 19.0 REQUIRED)
add_library(
BoltCore
@ -27,10 +28,11 @@ add_library(
src/ConsolePrinter.cc
src/Scanner.cc
src/Parser.cc
src/Types.cc
src/Type.cc
src/Checker.cc
src/Evaluator.cc
src/Scope.cc
src/Program.cc
)
target_link_directories(
BoltCore
@ -41,6 +43,7 @@ target_compile_options(
BoltCore
PUBLIC
-Werror
-fno-exceptions
${ICU_CFLAGS}
)
@ -68,16 +71,18 @@ add_library(
BoltLLVM
src/LLVMCodeGen.cc
)
llvm_map_components_to_libnames(llvm_libs support core irreader)
target_include_directories(BoltLLVM PRIVATE ${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
target_compile_definitions(BoltLLVM PRIVATE ${LLVM_DEFINITIONS_LIST})
target_link_libraries(
BoltLLVM
PUBLIC
BoltCore
${llvm_libs}
LLVMCore
LLVMTarget
)
target_include_directories(
BoltLLVM
PUBLIC
deps/llvm-project/llvm/include # FIXME this is a hack
${CMAKE_BINARY_DIR}/deps/llvm-project/llvm/include # FIXME this is a hack
)
add_executable(

View file

@ -1,24 +1,23 @@
#ifndef BOLT_CST_HPP
#define BOLT_CST_HPP
#include <cmath>
#include <cstdlib>
#include <limits>
#include <unordered_map>
#include <variant>
#include <vector>
#include <optional>
#include "bolt/Common.hpp"
#include "zen/config.hpp"
#include "bolt/Common.hpp"
#include "bolt/Integer.hpp"
#include "bolt/String.hpp"
#include "bolt/ByteString.hpp"
#include "bolt/Type.hpp"
namespace bolt {
class Type;
class InferContext;
class Token;
class SourceFile;
class Scope;
@ -1265,6 +1264,8 @@ public:
return Ty;
}
static bool classof(Node* N);
};
class TypeExpression : public TypedNode, AnnotationContainer {
@ -1273,6 +1274,19 @@ protected:
inline TypeExpression(NodeKind Kind, std::vector<Annotation*> Annotations = {}):
TypedNode(Kind), AnnotationContainer(Annotations) {}
public:
static bool classof(Node* N) {
return N->getKind() == NodeKind::ReferenceTypeExpression
|| N->getKind() == NodeKind::AppTypeExpression
|| N->getKind() == NodeKind::NestedTypeExpression
|| N->getKind() == NodeKind::ArrowTypeExpression
|| N->getKind() == NodeKind::VarTypeExpression
|| N->getKind() == NodeKind::TupleTypeExpression
|| N->getKind() == NodeKind::RecordTypeExpression
|| N->getKind() == NodeKind::QualifiedTypeExpression;
}
};
class ConstraintExpression : public Node {
@ -1740,6 +1754,21 @@ protected:
inline Expression(NodeKind Kind, std::vector<Annotation*> Annotations = {}):
TypedNode(Kind), AnnotationContainer(Annotations) {}
public:
static bool classof(Node* N) {
return N->getKind() == NodeKind::ReferenceExpression
|| N->getKind() == NodeKind::NestedExpression
|| N->getKind() == NodeKind::CallExpression
|| N->getKind() == NodeKind::TupleExpression
|| N->getKind() == NodeKind::InfixExpression
|| N->getKind() == NodeKind::RecordExpression
|| N->getKind() == NodeKind::MatchExpression
|| N->getKind() == NodeKind::MemberExpression
|| N->getKind() == NodeKind::LiteralExpression
|| N->getKind() == NodeKind::PrefixExpression;
}
};
class ReferenceExpression : public Expression {
@ -1780,8 +1809,6 @@ class MatchCase : public Node {
public:
InferContext* Ctx;
class Pattern* Pattern;
class RArrowAlt* RArrowAlt;
class Expression* Expression;
@ -2117,6 +2144,14 @@ protected:
inline Statement(NodeKind Type, std::vector<Annotation*> Annotations = {}):
Node(Type), AnnotationContainer(Annotations) {}
public:
static bool classof(Node* N) {
return N->getKind() == NodeKind::ExpressionStatement
|| N->getKind() == NodeKind::ReturnStatement
|| N->getKind() == NodeKind::IfStatement;
}
};
class ExpressionStatement : public Statement {
@ -2192,14 +2227,14 @@ class ReturnStatement : public Statement {
public:
class ReturnKeyword* ReturnKeyword;
class Expression* Expression;
Expression* E;
ReturnStatement(
class ReturnKeyword* ReturnKeyword,
class Expression* Expression
): Statement(NodeKind::ReturnStatement),
ReturnKeyword(ReturnKeyword),
Expression(Expression) {}
E(Expression) {}
ReturnStatement(
std::vector<Annotation*> Annotations,
@ -2207,11 +2242,19 @@ public:
class Expression* Expression
): Statement(NodeKind::ReturnStatement, Annotations),
ReturnKeyword(ReturnKeyword),
Expression(Expression) {}
E(Expression) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
bool hasExpression() const {
return E;
}
Expression* getExpression() {
return E;
}
};
class TypeAssert : public Node {
@ -2297,7 +2340,44 @@ public:
};
class FunctionDeclaration : public TypedNode, public AnnotationContainer {
class Declaration : public TypedNode, public AnnotationContainer {
std::optional<TypeScheme> Scm;
protected:
inline Declaration(NodeKind Kind, std::vector<Annotation*> Annotations = {}):
TypedNode(Kind), AnnotationContainer(Annotations) {}
public:
static bool classof(const Node* N) {
return N->getKind() == NodeKind::VariantDeclaration
|| N->getKind() == NodeKind::RecordDeclaration
|| N->getKind() == NodeKind::VariantDeclaration
|| N->getKind() == NodeKind::PrefixFunctionDeclaration
|| N->getKind() == NodeKind::InfixFunctionDeclaration
|| N->getKind() == NodeKind::SuffixFunctionDeclaration
|| N->getKind() == NodeKind::NamedFunctionDeclaration;
}
const TypeScheme& getScheme() const {
ZEN_ASSERT(Scm.has_value());
return *Scm;
}
bool hasScheme() const {
return Scm.has_value();
}
void setScheme(TypeScheme NewScm) {
Scm = NewScm;
}
};
class FunctionDeclaration : public Declaration {
Scope* TheScope = nullptr;
@ -2305,10 +2385,9 @@ public:
bool IsCycleActive = false;
bool Visited = false;
InferContext* Ctx;
FunctionDeclaration(NodeKind Kind, std::vector<Annotation*> Annotations = {}):
TypedNode(Kind), AnnotationContainer(Annotations) {}
Declaration(Kind, Annotations) {}
virtual bool isPublic() const = 0;
@ -2604,7 +2683,7 @@ public:
};
class VariableDeclaration : public TypedNode, public AnnotationContainer {
class VariableDeclaration : public Declaration {
public:
class PubKeyword* PubKeyword;
@ -2625,8 +2704,7 @@ public:
class Pattern* Pattern,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::VariableDeclaration),
AnnotationContainer(Annotations),
): Declaration(NodeKind::VariableDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
@ -2651,6 +2729,15 @@ public:
return N->getKind() == NodeKind::VariableDeclaration;
}
bool hasExpression() const {
return Body;
}
Expression* getExpression() {
ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody);
return static_cast<LetExprBody*>(Body)->Expression;
}
};
class InstanceDeclaration : public Node {
@ -2739,11 +2826,9 @@ public:
};
class RecordDeclaration : public Node {
class RecordDeclaration : public Declaration {
public:
InferContext* Ctx;
class PubKeyword* PubKeyword;
class StructKeyword* StructKeyword;
IdentifierAlt* Name;
@ -2758,7 +2843,7 @@ public:
std::vector<VarTypeExpression*> Vars,
class BlockStart* BlockStart,
std::vector<RecordDeclarationField*> Fields
): Node(NodeKind::RecordDeclaration),
): Declaration(NodeKind::RecordDeclaration),
PubKeyword(PubKeyword),
StructKeyword(StructKeyword),
Name(Name),
@ -2818,11 +2903,9 @@ public:
};
class VariantDeclaration : public Node {
class VariantDeclaration : public Declaration {
public:
InferContext* Ctx;
class PubKeyword* PubKeyword;
class EnumKeyword* EnumKeyword;
class IdentifierAlt* Name;
@ -2837,7 +2920,7 @@ public:
std::vector<VarTypeExpression*> TVs,
class BlockStart* BlockStart,
std::vector<VariantDeclarationMember*> Members
): Node(NodeKind::VariantDeclaration),
): Declaration(NodeKind::VariantDeclaration),
PubKeyword(PubKeyword),
EnumKeyword(EnumKeyword),
Name(Name),
@ -2857,7 +2940,6 @@ class SourceFile : public Node {
public:
TextFile File;
InferContext* Ctx;
std::vector<Node*> Elements;

View file

@ -1,9 +1,6 @@
#pragma once
#include "CST.hpp"
#include "zen/config.hpp"
#include "bolt/CST.hpp"
namespace bolt {
@ -18,6 +15,10 @@ public:
case NodeKind::name: \
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
#define BOLT_VISIT_SYMBOL(node) static_cast<D*>(this)->dispatchSymbol(node)
#define BOLT_VISIT_OPERATOR(node) static_cast<D*>(this)->dispatchOperator(node)
switch (N->getKind()) {
BOLT_GEN_CASE(VBar)
BOLT_GEN_CASE(Equals)
@ -123,13 +124,13 @@ public:
void dispatchSymbol(const Symbol& S) {
switch (S.getKind()) {
case NodeKind::Identifier:
visit(S.asIdentifier());
BOLT_VISIT(S.asIdentifier());
break;
case NodeKind::IdentifierAlt:
visit(S.asIdentifierAlt());
BOLT_VISIT(S.asIdentifierAlt());
break;
case NodeKind::WrappedOperator:
visit(S.asWrappedOperator());
BOLT_VISIT(S.asWrappedOperator());
break;
default:
ZEN_UNREACHABLE
@ -139,10 +140,10 @@ public:
void dispatchOperator(const Operator& O) {
switch (O.getKind()) {
case NodeKind::VBar:
visit(O.asVBar());
BOLT_VISIT(O.asVBar());
break;
case NodeKind::CustomOperator:
visit(O.asCustomOperator());
BOLT_VISIT(O.asCustomOperator());
break;
default:
ZEN_UNREACHABLE
@ -698,10 +699,6 @@ public:
}
}
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
#define BOLT_VISIT_SYMBOL(node) static_cast<D*>(this)->dispatchSymbol(node)
#define BOLT_VISIT_OPERATOR(node) static_cast<D*>(this)->dispatchOperator(node)
void visitEachChild(VBar* N) {
}
@ -1152,7 +1149,7 @@ public:
BOLT_VISIT(A);
}
BOLT_VISIT(N->ReturnKeyword);
BOLT_VISIT(N->Expression);
BOLT_VISIT(N->E);
}
void visitEachChild(IfStatement* N) {

View file

@ -3,340 +3,145 @@
#include <cstdlib>
#include <unordered_map>
#include <vector>
#include <deque>
#include <unordered_set>
#include "zen/tuple_hash.hpp"
#include "bolt/ByteString.hpp"
#include "bolt/Common.hpp"
#include "bolt/CST.hpp"
#include "bolt/DiagnosticEngine.hpp"
#include "bolt/Type.hpp"
#include "bolt/Support/Graph.hpp"
namespace bolt {
std::string describe(const Type* Ty); // For debugging only
enum class SymKind {
Type,
Var,
enum class ConstraintKind {
TypesEqual,
};
class DiagnosticEngine;
class Constraint {
class Constraint;
using ConstraintSet = std::vector<Constraint*>;
enum class SchemeKind : unsigned char {
Forall,
};
class Scheme {
const SchemeKind Kind;
ConstraintKind Kind;
protected:
inline Scheme(SchemeKind Kind):
Constraint(ConstraintKind Kind):
Kind(Kind) {}
public:
inline SchemeKind getKind() const noexcept {
inline ConstraintKind getKind() const {
return Kind;
}
virtual ~Scheme() {}
};
class Forall : public Scheme {
class CTypesEqual : public Constraint {
Type* A;
Type* B;
Node* Origin;
public:
TVSet* TVs;
ConstraintSet* Constraints;
class Type* Type;
CTypesEqual(Type* A, Type* B, Node* Origin):
Constraint(ConstraintKind::TypesEqual), A(A), B(B), Origin(Origin) {}
inline Forall(class Type* Type):
Scheme(SchemeKind::Forall), TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {}
Type* getLeft() const {
return A;
}
inline Forall(
TVSet* TVs,
ConstraintSet* Constraints,
class Type* Type
): Scheme(SchemeKind::Forall),
TVs(TVs),
Constraints(Constraints),
Type(Type) {}
Type* getRight() const {
return B;
}
static bool classof(const Scheme* Scm) {
return Scm->getKind() == SchemeKind::Forall;
Node* getOrigin() const {
return Origin;
}
};
class TypeEnv {
std::unordered_map<std::tuple<ByteString, SymKind>, Scheme*> Mapping;
TypeEnv* Parent;
std::unordered_map<std::tuple<ByteString, SymbolKind>, TypeScheme*> Mapping;
public:
Scheme* lookup(ByteString Name, SymKind Kind) {
auto Key = std::make_tuple(Name, Kind);
auto Match = Mapping.find(Key);
if (Match == Mapping.end()) {
return nullptr;
}
return Match->second;
}
TypeEnv(TypeEnv* Parent = nullptr):
Parent(Parent) {}
void add(ByteString Name, Scheme* Scm, SymKind Kind) {
auto Key = std::make_tuple(Name, Kind);
ZEN_ASSERT(!Mapping.count(Key))
// auto F = static_cast<Forall*>(Scm);
// std::cerr << Name << " : forall ";
// for (auto TV: *F->TVs) {
// std::cerr << describe(TV) << " ";
// }
// std::cerr << ". " << describe(F->Type) << "\n";
Mapping.emplace(Key, Scm);
}
void add(ByteString Name, Type* Ty, SymbolKind Kind);
void add(ByteString Name, TypeScheme* Ty, SymbolKind Kind);
bool hasVar(TVar* TV) const;
TypeScheme* lookup(ByteString Name, SymbolKind Kind);
};
enum class ConstraintKind {
Equal,
Field,
Many,
Empty,
};
class Constraint {
const ConstraintKind Kind;
public:
inline Constraint(ConstraintKind Kind):
Kind(Kind) {}
inline ConstraintKind getKind() const noexcept {
return Kind;
}
Constraint* substitute(const TVSub& Sub);
virtual ~Constraint() {}
};
class CEqual : public Constraint {
public:
Type* Left;
Type* Right;
Node* Source;
inline CEqual(Type* Left, Type* Right, Node* Source = nullptr):
Constraint(ConstraintKind::Equal), Left(Left), Right(Right), Source(Source) {}
};
class CField : public Constraint {
public:
Type* TupleTy;
size_t I;
Type* FieldTy;
Node* Source;
inline CField(Type* TupleTy, size_t I, Type* FieldTy, Node* Source = nullptr):
Constraint(ConstraintKind::Field), TupleTy(TupleTy), I(I), FieldTy(FieldTy), Source(Source) {}
};
class CMany : public Constraint {
public:
ConstraintSet& Elements;
inline CMany(ConstraintSet& Elements):
Constraint(ConstraintKind::Many), Elements(Elements) {}
};
class CEmpty : public Constraint {
public:
inline CEmpty():
Constraint(ConstraintKind::Empty) {}
};
using InferContextFlagsMask = unsigned;
class InferContext {
public:
/**
* A heap-allocated list of type variables that eventually will become part of a Forall scheme.
*/
TVSet* TVs;
/**
* A heap-allocated list of constraints that eventually will become part of a Forall scheme.
*/
ConstraintSet* Constraints;
TypeEnv Env;
Type* ReturnType = nullptr;
InferContext* Parent = nullptr;
};
using ConstraintSet = std::vector<Constraint*>;
class Checker {
friend class Unifier;
friend class UnificationFrame;
const LanguageConfig& Config;
DiagnosticEngine& DE;
size_t NextConTypeId = 0;
size_t NextTypeVarId = 0;
Type* BoolType;
Type* ListType;
Type* IntType;
Type* BoolType;
Type* StringType;
Type* UnitType;
Graph<Node*> RefGraph;
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
/// Inference context management
InferContext* ActiveContext;
InferContext& getContext();
void setContext(InferContext* Ctx);
void popContext();
void makeEqual(Type* A, Type* B, Node* Source);
void addConstraint(Constraint* Constraint);
/**
* Get the return type for the current context. If none could be found, the
* program will abort.
*/
Type* getReturnType();
/// Type inference
void forwardDeclare(Node* Node);
void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true);
Type* inferLiteral(Literal* Lit);
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void infer(Node* node);
void inferFunctionDeclaration(FunctionDeclaration* N);
void inferConstraintExpression(ConstraintExpression* C);
/// Factory methods
Type* createConType(ByteString Name);
Type* createTypeVar();
Type* createRigidVar(ByteString Name);
InferContext* createInferContext(
InferContext* Parent = nullptr,
TVSet* TVs = new TVSet,
ConstraintSet* Constraints = new ConstraintSet
);
/// Environment manipulation
Scheme* lookup(ByteString Name, SymKind Kind);
/**
* Looks up a type/variable and ensures that it is a monomorphic type.
*
* This method is mainly syntactic sugar to make it clear in the code when a
* monomorphic type is expected.
*
* Note that if the type is not monomorphic the program will abort with a
* stack trace. It wil **not** print a user-friendly error message.
*
* \returns If the type/variable could not be found `nullptr` is returned.
* Otherwise, a [Type] is returned.
*/
Type* lookupMono(ByteString Name, SymKind Kind);
void addBinding(ByteString Name, Scheme* Scm, SymKind Kind);
/// Constraint solving
/**
* The queue that is used during solving to store any unsolved constraints.
*/
std::deque<class Constraint*> Queue;
/**
* Unify two types, using `Source` as source location.
*
* \returns Whether a type variable was assigned a type or not.
*/
bool unify(Type* Left, Type* Right, Node* Source);
void solve(Constraint* Constraint);
/// Helpers
void populate(SourceFile* SF);
/**
* Verifies that type class signatures on type asserts in let-declarations
* correctly declare the right type classes.
*/
void checkTypeclassSigs(Node* N);
Type* instantiate(Scheme* S, Node* Source);
void initialize(Node* N);
public:
Checker(const LanguageConfig& Config, DiagnosticEngine& DE);
/**
* \internal
*/
Type* solveType(Type* Ty);
void check(SourceFile* SF);
inline Type* getBoolType() const {
return BoolType;
Checker(DiagnosticEngine& DE):
DE(DE) {
IntType = new TCon("Int");
BoolType = new TCon("Bool");
StringType = new TCon("String");
}
inline Type* getStringType() const {
return StringType;
}
inline Type* getIntType() const {
Type* getIntType() const {
return IntType;
}
Type* getType(TypedNode* Node);
Type* getBoolType() const {
return BoolType;
}
Type* getStringType() const {
return StringType;
}
TVar* createTVar() {
return new TVar();
}
Type* instantiate(TypeScheme* Scm);
void visitPattern(Pattern* P, Type* Ty, TypeEnv& Out);
ConstraintSet inferSourceFile(TypeEnv& Env, SourceFile* SF);
ConstraintSet inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaration* D);
ConstraintSet inferVariableDeclaration(TypeEnv& Env, VariableDeclaration* Decl);
ConstraintSet inferMany(TypeEnv& Env, std::vector<Node*>& N, Type* RetTy);
ConstraintSet inferElement(TypeEnv& Env, Node* N, Type* RetTy);
std::tuple<ConstraintSet, Type*> inferTypeExpr(TypeEnv& Env, TypeExpression* TE);
std::tuple<ConstraintSet, Type*> inferExpr(TypeEnv& Env, Expression* Expr, Type* RetTy);
ConstraintSet checkExpr(TypeEnv& Env, Expression* Expr, Type* Expected, Type* RetTy);
void solve(const std::vector<Constraint*>& Constraints);
void unifyTypeType(Type* A, Type* B, Node* Source);
void run(SourceFile* SF);
Type* getTypeOfNode(Node* N);
};

View file

@ -1,6 +1,8 @@
#pragma once
#include <cstdlib>
#include "zen/config.hpp"
namespace bolt {

View file

@ -5,13 +5,10 @@
#include "bolt/ByteString.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
namespace bolt {
class Node;
class Type;
class TypeclassSignature;
class Diagnostic;
enum class Color {
@ -160,12 +157,8 @@ class ConsolePrinter {
void writePrefix(const Diagnostic& D);
void writeBinding(const ByteString& Name);
void writeType(std::size_t I);
void writeType(const Type* Ty, const TypePath& Underline);
void writeType(const Type* Ty);
void writeLoc(const TextFile& File, const TextLoc& Loc);
void writeTypeclassName(const ByteString& Name);
void writeTypeclassSignature(const TypeclassSignature& Sig);
void writeType(Type* Ty);
void write(const std::string_view& S);
void write(std::size_t N);

View file

@ -1,6 +1,7 @@
#pragma once
#include <cwchar>
#include <vector>
#include "bolt/ByteString.hpp"
@ -12,15 +13,15 @@ namespace bolt {
enum class DiagnosticKind : unsigned char {
BindingNotFound,
FieldNotFound,
InstanceNotFound,
InvalidTypeToTypeclass,
NotATuple,
TupleIndexOutOfRange,
TypeclassMissing,
// FieldNotFound,
// InstanceNotFound,
// InvalidTypeToTypeclass,
// NotATuple,
// TupleIndexOutOfRange,
// TypeclassMissing,
UnexpectedString,
UnexpectedToken,
UnificationError,
TypeMismatchError,
};
class Diagnostic {
@ -33,7 +34,7 @@ protected:
public:
inline DiagnosticKind getKind() const noexcept {
inline DiagnosticKind getKind() const {
return Kind;
}
@ -41,7 +42,7 @@ public:
return nullptr;
}
virtual unsigned getCode() const noexcept = 0;
virtual unsigned getCode() const = 0;
virtual ~Diagnostic() {}
@ -57,7 +58,7 @@ public:
inline UnexpectedStringDiagnostic(TextFile& File, TextLoc Location, String Actual):
Diagnostic(DiagnosticKind::UnexpectedString), File(File), Location(Location), Actual(Actual) {}
unsigned getCode() const noexcept override {
unsigned getCode() const override {
return 1001;
}
@ -73,7 +74,7 @@ public:
inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector<NodeKind> Expected):
Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {}
unsigned getCode() const noexcept override {
unsigned getCode() const override {
return 1101;
}
@ -92,153 +93,28 @@ public:
return Initiator;
}
unsigned getCode() const noexcept override {
unsigned getCode() const override {
return 2005;
}
};
class UnificationErrorDiagnostic : public Diagnostic {
class TypeMismatchError : public Diagnostic {
public:
Type* OrigLeft;
Type* OrigRight;
TypePath LeftPath;
TypePath RightPath;
Node* Source;
Type* Left;
Type* Right;
Node* N;
inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source):
Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {}
inline Type* getLeft() const {
return OrigLeft->resolve(LeftPath);
}
inline Type* getRight() const {
return OrigRight->resolve(RightPath);
}
inline TypeMismatchError(Type* Left, Type* Right, Node* N):
Diagnostic(DiagnosticKind::TypeMismatchError), Left(Left), Right(Right), N(N) {}
inline Node* getNode() const override {
return Source;
return N;
}
unsigned getCode() const noexcept override {
return 2010;
}
};
class TypeclassMissingDiagnostic : public Diagnostic {
public:
TypeclassSignature Sig;
Node* Decl;
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, Node* Decl):
Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {}
inline Node* getNode() const override {
return Decl;
}
unsigned getCode() const noexcept override {
return 2201;
}
};
class InstanceNotFoundDiagnostic : public Diagnostic {
public:
ByteString TypeclassName;
Type* Ty;
Node* Source;
inline InstanceNotFoundDiagnostic(ByteString TypeclassName, Type* Ty, Node* Source):
Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {}
inline Node* getNode() const override {
return Source;
}
unsigned getCode() const noexcept override {
return 2251;
}
};
class TupleIndexOutOfRangeDiagnostic : public Diagnostic {
public:
Type* Tuple;
std::size_t I;
Node* Source;
inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I, Node* Source):
Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I), Source(Source) {}
inline Node * getNode() const override {
return Source;
}
unsigned getCode() const noexcept override {
return 2015;
}
};
class InvalidTypeToTypeclassDiagnostic : public Diagnostic {
public:
Type* Actual;
std::vector<TypeclassId> Classes;
Node* Source;
inline InvalidTypeToTypeclassDiagnostic(Type* Actual, std::vector<TypeclassId> Classes, Node* Source):
Diagnostic(DiagnosticKind::InvalidTypeToTypeclass), Actual(Actual), Classes(Classes), Source(Source) {}
inline Node* getNode() const override {
return Source;
}
unsigned getCode() const noexcept override {
return 2060;
}
};
class FieldNotFoundDiagnostic : public Diagnostic {
public:
ByteString Name;
Type* Ty;
TypePath Path;
Node* Source;
inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source):
Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {}
unsigned getCode() const noexcept override {
return 2017;
}
};
class NotATupleDiagnostic : public Diagnostic {
public:
Type* Ty;
Node* Source;
inline NotATupleDiagnostic(Type* Ty, Node* Source):
Diagnostic(DiagnosticKind::NotATuple), Ty(Ty), Source(Source) {}
inline Node * getNode() const override {
return Source;
}
unsigned getCode() const noexcept override {
return 2016;
unsigned getCode() const override {
return 3001;
}
};

124
include/bolt/Either.hpp Normal file
View file

@ -0,0 +1,124 @@
#pragma once
#include <cstdlib>
#include <concepts>
#include <string>
#include <utility>
#include "bolt/Common.hpp"
namespace bolt {
template<typename T>
concept ErrorLike = requires (T a) {
{ message(a) } -> std::convertible_to<std::string>;
};
template<typename T>
struct Left {
T value;
};
template<typename T>
struct Right {
T value;
};
template<typename L, typename R>
class Either {
bool _is_left;
union {
L _left;
R _right;
};
public:
template<typename L1>
Either(const Left<L1>& left):
_is_left(true), _left(left.value) {}
template<typename R1>
Either(const Right<R1>& right):
_is_left(false), _right(right.value) {}
template<typename L1>
Either(Left<L1>&& left):
_is_left(true), _left(std::move(left.value)) {}
template<typename R2>
Either(Right<R2>&& right):
_is_left(false), _right(std::move(right.value)) {}
Either(const Either& other):
_is_left(_is_left) {
if (other._is_left) {
new (&_left)L(other._left);
} else {
new (&_right)L(other._right);
}
}
Either(Either&& other):
_is_left(std::move(other._is_left)) {
if (_is_left) {
new (&_left)L(std::move(other._left));
} else {
new (&_right)L(std::move(other._right));
}
}
bool is_left() const {
return _is_left;
}
auto left() const {
return _left;
}
auto right() const {
return _right;
}
R&& unwrap() requires ErrorLike<L> {
if (_is_left) {
auto desc = message(_left);
ZEN_PANIC("trying to unwrap a result containing an error: %s", desc.c_str());
}
return std::move(_right);
}
~Either() {
if (_is_left) {
_left.~L();
} else {
_right.~R();
}
}
};
// template<typename L>
// auto left(const L& value) {
// return Left<L> { value };
// }
template<typename L>
auto left(L&& value) {
return Left<L> { std::move(value) };
}
// template<typename R>
// auto right(const R& value) {
// return Right<R> { value };
// }
template<typename R>
auto right(R&& value) {
return Right<R> { std::move(value) };
}
}

View file

@ -5,6 +5,8 @@
#include <filesystem>
#include <unordered_map>
#include "zen/range.hpp"
#include "bolt/Common.hpp"
#include "bolt/Checker.hpp"
#include "bolt/DiagnosticEngine.hpp"
@ -47,12 +49,12 @@ public:
if (Match != TCs.end()) {
return Match->second;
}
return TCs.emplace(SF, Checker { Config, DE }).first->second;
return TCs.emplace(SF, Checker { DE }).first->second;
}
void check() {
for (auto SF: getSourceFiles()) {
getTypeChecker(SF).check(SF);
getTypeChecker(SF).run(SF);
}
}

View file

@ -1,61 +1,32 @@
#pragma once
#include <functional>
#include <optional>
#include <unistd.h>
#include <unordered_set>
#include <unordered_map>
#include <cstddef>
#include <cwchar>
#include <vector>
#include <unordered_set>
#include "zen/config.hpp"
#include "bolt/CST.hpp"
#include "bolt/ByteString.hpp"
namespace bolt {
class Type;
class TCon;
using TypeclassId = ByteString;
using TypeclassContext = std::unordered_set<TypeclassId>;
struct TypeclassSignature {
using TypeclassId = ByteString;
TypeclassId Id;
std::vector<Type*> Params;
bool operator<(const TypeclassSignature& Other) const;
bool operator==(const TypeclassSignature& Other) const;
};
struct TypeSig {
Type* Orig;
Type* Op;
std::vector<Type*> Args;
};
enum class TypeIndexKind {
AppOpType,
AppArgType,
ArrowParamType,
ArrowReturnType,
AppOp,
AppArg,
ArrowLeft,
ArrowRight,
TupleElement,
FieldType,
FieldRestType,
PresentType,
FieldElement,
FieldRest,
PresentElement,
End,
};
class TypeIndex {
protected:
friend class Type;
friend class TypeIterator;
TypeIndexKind Kind;
@ -71,685 +42,202 @@ protected:
public:
bool operator==(const TypeIndex& Other) const noexcept;
void advance(const Type* Ty);
static TypeIndex forFieldType() {
return { TypeIndexKind::FieldType };
static TypeIndex forAppOp() {
return { TypeIndexKind::AppOp };
}
static TypeIndex forFieldRest() {
return { TypeIndexKind::FieldRestType };
static TypeIndex forAppArg() {
return { TypeIndexKind::AppArg };
}
static TypeIndex forArrowParamType() {
return { TypeIndexKind::ArrowParamType };
static TypeIndex forArrowLeft() {
return { TypeIndexKind::ArrowLeft };
}
static TypeIndex forArrowReturnType() {
return { TypeIndexKind::ArrowReturnType };
static TypeIndex forArrowRight() {
return { TypeIndexKind::ArrowRight };
}
static TypeIndex forTupleElement(std::size_t I) {
static TypeIndex forTupleIndex(std::size_t I) {
return { TypeIndexKind::TupleElement, I };
}
static TypeIndex forAppOpType() {
return { TypeIndexKind::AppOpType };
}
static TypeIndex forAppArgType() {
return { TypeIndexKind::AppArgType };
}
static TypeIndex forPresentType() {
return { TypeIndexKind::PresentType };
}
};
class TypeIterator {
friend class Type;
Type* Ty;
TypeIndex Index;
TypeIterator(Type* Ty, TypeIndex Index):
Ty(Ty), Index(Index) {}
public:
TypeIterator& operator++() noexcept {
Index.advance(Ty);
return *this;
}
bool operator==(const TypeIterator& Other) const noexcept {
return Ty == Other.Ty && Index == Other.Index;
}
Type* operator*() {
return Ty;
}
TypeIndex getIndex() const noexcept {
return Index;
}
};
using TypePath = std::vector<TypeIndex>;
using TVSub = std::unordered_map<Type*, Type*>;
using TVSet = std::unordered_set<Type*>;
enum class TypeKind : unsigned char {
enum class TypeKind {
Var,
Con,
Fun,
App,
Arrow,
Tuple,
Field,
Nil,
Absent,
Present,
};
class Type;
class TVar;
class TCon;
class TFun;
class TApp;
struct TCon {
size_t Id;
ByteString DisplayName;
class Type {
protected:
bool operator==(const TCon& Other) const;
TypeKind TK;
};
Type(TypeKind TK):
TK(TK) {}
struct TApp {
Type* Op;
Type* Arg;
public:
bool operator==(const TApp& Other) const;
};
enum class VarKind {
Rigid,
Unification,
};
struct TVar {
VarKind VK;
size_t Id;
TypeclassContext Context;
std::optional<ByteString> Name;
std::optional<TypeclassContext> Provided;
VarKind getKind() const {
return VK;
virtual Type* find() const {
return const_cast<Type*>(this);
}
bool isUni() const {
return VK == VarKind::Unification;
}
bool isRigid() const {
return VK == VarKind::Rigid;
}
bool operator==(const TVar& Other) const;
};
struct TArrow {
Type* ParamType;
Type* ReturnType;
bool operator==(const TArrow& Other) const;
};
struct TTuple {
std::vector<Type*> ElementTypes;
bool operator==(const TTuple& Other) const;
};
struct TNil {
bool operator==(const TNil& Other) const;
};
struct TField {
ByteString Name;
Type* Ty;
Type* RestTy;
bool operator==(const TField& Other) const;
};
struct TAbsent {
bool operator==(const TAbsent& Other) const;
};
struct TPresent {
Type* Ty;
bool operator==(const TPresent& Other) const;
};
struct Type {
TypeKind Kind;
Type* Parent = this;
union {
TCon Con;
TApp App;
TVar Var;
TArrow Arrow;
TTuple Tuple;
TNil Nil;
TField Field;
TAbsent Absent;
TPresent Present;
};
Type(TCon&& Con):
Kind(TypeKind::Con), Con(std::move(Con)) {};
Type(TApp&& App):
Kind(TypeKind::App), App(std::move(App)) {};
Type(TVar&& Var):
Kind(TypeKind::Var), Var(std::move(Var)) {};
Type(TArrow&& Arrow):
Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {};
Type(TTuple&& Tuple):
Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {};
Type(TNil&& Nil):
Kind(TypeKind::Nil), Nil(std::move(Nil)) {};
Type(TField&& Field):
Kind(TypeKind::Field), Field(std::move(Field)) {};
Type(TAbsent&& Absent):
Kind(TypeKind::Absent), Absent(std::move(Absent)) {};
Type(TPresent&& Present):
Kind(TypeKind::Present), Present(std::move(Present)) {};
Type(const Type& Other): Kind(Other.Kind) {
switch (Kind) {
case TypeKind::Con:
new (&Con)TCon(Other.Con);
break;
case TypeKind::App:
new (&App)TApp(Other.App);
break;
case TypeKind::Var:
new (&Var)TVar(Other.Var);
break;
case TypeKind::Arrow:
new (&Arrow)TArrow(Other.Arrow);
break;
case TypeKind::Tuple:
new (&Tuple)TTuple(Other.Tuple);
break;
case TypeKind::Nil:
new (&Nil)TNil(Other.Nil);
break;
case TypeKind::Field:
new (&Field)TField(Other.Field);
break;
case TypeKind::Absent:
new (&Absent)TAbsent(Other.Absent);
break;
case TypeKind::Present:
new (&Present)TPresent(Other.Present);
break;
}
}
Type(Type&& Other): Kind(std::move(Other.Kind)) {
switch (Kind) {
case TypeKind::Con:
new (&Con)TCon(std::move(Other.Con));
break;
case TypeKind::App:
new (&App)TApp(std::move(Other.App));
break;
case TypeKind::Var:
new (&Var)TVar(std::move(Other.Var));
break;
case TypeKind::Arrow:
new (&Arrow)TArrow(std::move(Other.Arrow));
break;
case TypeKind::Tuple:
new (&Tuple)TTuple(std::move(Other.Tuple));
break;
case TypeKind::Nil:
new (&Nil)TNil(std::move(Other.Nil));
break;
case TypeKind::Field:
new (&Field)TField(std::move(Other.Field));
break;
case TypeKind::Absent:
new (&Absent)TAbsent(std::move(Other.Absent));
break;
case TypeKind::Present:
new (&Present)TPresent(std::move(Other.Present));
break;
}
}
TypeKind getKind() const {
return Kind;
}
bool isVarRigid() const {
return Kind == TypeKind::Var
&& asVar().getKind() == VarKind::Rigid;
inline TypeKind getKind() const {
return TK;
}
bool isVar() const {
return Kind == TypeKind::Var;
}
TVar& asVar() {
ZEN_ASSERT(Kind == TypeKind::Var);
return Var;
}
const TVar& asVar() const {
ZEN_ASSERT(Kind == TypeKind::Var);
return Var;
}
bool isApp() const {
return Kind == TypeKind::App;
}
TApp& asApp() {
ZEN_ASSERT(Kind == TypeKind::App);
return App;
}
const TApp& asApp() const {
ZEN_ASSERT(Kind == TypeKind::App);
return App;
}
bool isCon() const {
return Kind == TypeKind::Con;
}
TCon& asCon() {
ZEN_ASSERT(Kind == TypeKind::Con);
return Con;
}
const TCon& asCon() const {
ZEN_ASSERT(Kind == TypeKind::Con);
return Con;
}
bool isArrow() const {
return Kind == TypeKind::Arrow;
}
TArrow& asArrow() {
ZEN_ASSERT(Kind == TypeKind::Arrow);
return Arrow;
}
const TArrow& asArrow() const {
ZEN_ASSERT(Kind == TypeKind::Arrow);
return Arrow;
}
bool isTuple() const {
return Kind == TypeKind::Tuple;
}
TTuple& asTuple() {
ZEN_ASSERT(Kind == TypeKind::Tuple);
return Tuple;
}
const TTuple& asTuple() const {
ZEN_ASSERT(Kind == TypeKind::Tuple);
return Tuple;
}
bool isField() const {
return Kind == TypeKind::Field;
}
TField& asField() {
ZEN_ASSERT(Kind == TypeKind::Field);
return Field;
}
const TField& asField() const {
ZEN_ASSERT(Kind == TypeKind::Field);
return Field;
}
bool isAbsent() const {
return Kind == TypeKind::Absent;
}
TAbsent& asAbsent() {
ZEN_ASSERT(Kind == TypeKind::Absent);
return Absent;
}
const TAbsent& asAbsent() const {
ZEN_ASSERT(Kind == TypeKind::Absent);
return Absent;
}
bool isPresent() const {
return Kind == TypeKind::Present;
}
TPresent& asPresent() {
ZEN_ASSERT(Kind == TypeKind::Present);
return Present;
}
const TPresent& asPresent() const {
ZEN_ASSERT(Kind == TypeKind::Present);
return Present;
}
bool isNil() const {
return Kind == TypeKind::Nil;
}
TNil& asNil() {
ZEN_ASSERT(Kind == TypeKind::Nil);
return Nil;
}
const TNil& asNil() const {
ZEN_ASSERT(Kind == TypeKind::Nil);
return Nil;
}
Type* rewrite(std::function<Type*(Type*)> Fn, bool Recursive = true);
Type* resolve(const TypeIndex& Index) const noexcept;
Type* resolve(const TypePath& Path) noexcept {
Type* Ty = this;
for (auto El: Path) {
Ty = Ty->resolve(El);
}
return Ty;
}
void set(Type* Ty) {
auto Root = find();
// It is not possible to set a solution twice.
if (isVar()) {
ZEN_ASSERT(Root->isVar());
}
Root->Parent = Ty;
}
Type* find() const {
Type* Curr = const_cast<Type*>(this);
for (;;) {
auto Keep = Curr->Parent;
if (Keep == Curr) {
return Keep;
}
Curr->Parent = Keep->Parent;
Curr = Keep;
}
return TK == TypeKind::Var;
}
bool operator==(const Type& Other) const;
void destroy() {
switch (Kind) {
case TypeKind::Con:
App.~TApp();
break;
case TypeKind::App:
App.~TApp();
break;
case TypeKind::Var:
Var.~TVar();
break;
case TypeKind::Arrow:
Arrow.~TArrow();
break;
case TypeKind::Tuple:
Tuple.~TTuple();
break;
case TypeKind::Nil:
Nil.~TNil();
break;
case TypeKind::Field:
Field.~TField();
break;
case TypeKind::Absent:
Absent.~TAbsent();
break;
case TypeKind::Present:
Present.~TPresent();
break;
}
}
std::string toString() const;
Type& operator=(Type& Other) {
destroy();
Kind = Other.Kind;
switch (Kind) {
case TypeKind::Con:
App = Other.App;
break;
case TypeKind::App:
App = Other.App;
break;
case TypeKind::Var:
Var = Other.Var;
break;
case TypeKind::Arrow:
Arrow = Other.Arrow;
break;
case TypeKind::Tuple:
Tuple = Other.Tuple;
break;
case TypeKind::Nil:
Nil = Other.Nil;
break;
case TypeKind::Field:
Field = Other.Field;
break;
case TypeKind::Absent:
Absent = Other.Absent;
break;
case TypeKind::Present:
Present = Other.Present;
break;
}
return *this;
}
Type* resolve(const TypePath& P);
bool hasTypeVar(Type* TV) const;
TVar* asVar();
const TVar* asVar() const;
TypeIterator begin();
TypeIterator end();
TFun* asFun();
const TFun* asFun() const;
TypeIndex getStartIndex() const;
TypeIndex getEndIndex() const;
Type* substitute(const TVSub& Sub);
void visitEachChild(std::function<void(Type*)> Proc);
TVSet getTypeVars();
~Type() {
destroy();
}
static Type* buildArrow(std::vector<Type*> ParamTypes, Type* ReturnType) {
Type* Curr = ReturnType;
for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) {
Curr = new Type(TArrow(*Iter, Curr));
}
return Curr;
}
TCon* asCon();
const TCon* asCon() const;
};
template<bool IsConst>
class TypeVisitorBase {
protected:
class TVar : public Type {
template<typename T>
using C = std::conditional<IsConst, const T, T>::type;
virtual void enterType(C<Type>* Ty) {}
virtual void exitType(C<Type>* Ty) {}
// virtual void visitType(C<Type>* Ty) {
// visitEachChild(Ty);
// }
virtual void visitVarType(C<TVar>& Ty) {
}
virtual void visitAppType(C<TApp>& Ty) {
visit(Ty.Op);
visit(Ty.Arg);
}
virtual void visitPresentType(C<TPresent>& Ty) {
visit(Ty.Ty);
}
virtual void visitConType(C<TCon>& Ty) {
}
virtual void visitArrowType(C<TArrow>& Ty) {
visit(Ty.ParamType);
visit(Ty.ReturnType);
}
virtual void visitTupleType(C<TTuple>& Ty) {
for (auto ElTy: Ty.ElementTypes) {
visit(ElTy);
}
}
virtual void visitAbsentType(C<TAbsent>& Ty) {
}
virtual void visitFieldType(C<TField>& Ty) {
visit(Ty.Ty);
visit(Ty.RestTy);
}
virtual void visitNilType(C<TNil>& Ty) {
}
Type* Parent = this;
public:
void visitEachChild(C<Type>* Ty) {
switch (Ty->getKind()) {
case TypeKind::Var:
case TypeKind::Absent:
case TypeKind::Nil:
case TypeKind::Con:
break;
case TypeKind::Arrow:
{
auto& Arrow = Ty->asArrow();
visit(Arrow->ParamType);
visit(Arrow->ReturnType);
break;
}
case TypeKind::Tuple:
{
auto& Tuple = Ty->asTuple();
for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) {
visit(Tuple->ElementTypes[I]);
}
break;
}
case TypeKind::App:
{
auto& App = Ty->asApp();
visit(App->Op);
visit(App->Arg);
break;
}
case TypeKind::Field:
{
auto& Field = Ty->asField();
visit(Field->Ty);
visit(Field->RestTy);
break;
}
case TypeKind::Present:
{
auto& Present = Ty->asPresent();
visit(Present->Ty);
break;
}
}
TVar():
Type(TypeKind::Var) {}
void set(Type* Ty) {
auto Root = find();
// It is not possible to set a solution twice.
ZEN_ASSERT(Root->isVar());
static_cast<TVar*>(Root)->Parent = Ty;
}
void visit(C<Type>* Ty) {
// Always look at the most solved solution
Ty = Ty->find();
enterType(Ty);
switch (Ty->getKind()) {
case TypeKind::Present:
visitPresentType(Ty->asPresent());
break;
case TypeKind::Absent:
visitAbsentType(Ty->asAbsent());
break;
case TypeKind::Nil:
visitNilType(Ty->asNil());
break;
case TypeKind::Field:
visitFieldType(Ty->asField());
break;
case TypeKind::Con:
visitConType(Ty->asCon());
break;
case TypeKind::Arrow:
visitArrowType(Ty->asArrow());
break;
case TypeKind::Var:
visitVarType(Ty->asVar());
break;
case TypeKind::Tuple:
visitTupleType(Ty->asTuple());
break;
case TypeKind::App:
visitAppType(Ty->asApp());
break;
Type* find() const override {
TVar* Curr = const_cast<TVar*>(this);
for (;;) {
auto Keep = Curr->Parent;
if (Keep == Curr || !Keep->isVar()) {
return Keep;
}
auto Keep2 = static_cast<TVar*>(Keep);
Curr->Parent = Keep2->Parent;
Curr = Keep2;
}
exitType(Ty);
}
virtual ~TypeVisitorBase() {}
};
using TypeVisitor = TypeVisitorBase<false>;
using ConstTypeVisitor = TypeVisitorBase<true>;
class TCon : public Type {
ByteString Name;
public:
TCon(ByteString Name):
Type(TypeKind::Con), Name(Name) {}
ByteStringView getName() const {
return Name;
}
};
class TFun : public Type {
Type* Left;
Type* Right;
public:
TFun(Type* Left, Type* Right):
Type(TypeKind::Fun), Left(Left), Right(Right) {}
Type* getLeft() const {
return Left;
}
Type* getRight() const {
return Right;
}
};
class TApp : public Type {
Type* Left;
Type* Right;
public:
TApp(Type* Left, Type* Right):
Type(TypeKind::App), Left(Left), Right(Right) {}
Type* getLeft() const {
return Left;
}
Type* getRight() const {
return Right;
}
};
struct TypeScheme {
std::unordered_set<TVar*> Unbound;
Type* Ty;
Type* getType() const {
return Ty;
}
};
class TypeVisitor {
public:
void visit(Type* Ty);
virtual void visitVar(TVar* TV) {
}
virtual void visitApp(TApp* App) {
visit(App->getLeft());
visit(App->getRight());
}
virtual void visitCon(TCon* Con) {
}
virtual void visitFun(TFun* Fun) {
visit(Fun->getLeft());
visit(Fun->getRight());
}
};
}

View file

@ -1,6 +1,4 @@
#include "zen/config.hpp"
#include "bolt/CST.hpp"
#include "bolt/CSTVisitor.hpp"
@ -512,8 +510,8 @@ Token* ReturnStatement::getFirstToken() const {
}
Token* ReturnStatement::getLastToken() const {
if (Expression) {
return Expression->getLastToken();
if (E) {
return E->getLastToken();
}
return ReturnKeyword;
}
@ -1036,5 +1034,12 @@ SymbolPath ReferenceExpression::getSymbolPath() const {
return SymbolPath { ModuleNames, Name.getCanonicalText() };
}
bool TypedNode::classof(Node* N) {
return Expression::classof(N)
|| TypeExpression::classof(N)
|| FunctionDeclaration::classof(N)
|| VariableDeclaration::classof(N);
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,11 +1,9 @@
// FIXME writeExcerpt does not work well with the last line in a file
#include <sstream>
#include <functional>
#include <cmath>
#include "zen/config.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
#include "bolt/Diagnostics.hpp"
@ -182,6 +180,8 @@ static std::string describe(NodeKind Type) {
return "a variant";
case NodeKind::MatchCase:
return "a match-arm";
case NodeKind::LetExprBody:
return "the body of a let-declaration";
default:
ZEN_UNREACHABLE
}
@ -199,79 +199,6 @@ static std::string describe(Token* T) {
}
}
std::string describe(const Type* Ty) {
Ty = Ty->find();
switch (Ty->getKind()) {
case TypeKind::Var:
{
auto TV = Ty->asVar();
if (TV.isRigid()) {
return *TV.Name;
}
return "a" + std::to_string(TV.Id);
}
case TypeKind::Arrow:
{
auto Y = Ty->asArrow();
std::ostringstream Out;
Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType);
return Out.str();
}
case TypeKind::Con:
{
auto Y = Ty->asCon();
return Y.DisplayName;
}
case TypeKind::App:
{
auto Y = Ty->asApp();
return describe(Y.Op) + " " + describe(Y.Arg);
}
case TypeKind::Tuple:
{
std::ostringstream Out;
auto Y = Ty->asTuple();
Out << "(";
if (Y.ElementTypes.size()) {
auto Iter = Y.ElementTypes.begin();
Out << describe(*Iter++);
while (Iter != Y.ElementTypes.end()) {
Out << ", " << describe(*Iter++);
}
}
Out << ")";
return Out.str();
}
case TypeKind::Nil:
return "{}";
case TypeKind::Absent:
return "Abs";
case TypeKind::Present:
{
auto Y = Ty->asPresent();
return describe(Y.Ty);
}
case TypeKind::Field:
{
auto Y = Ty->asField();
std::ostringstream out;
out << "{ " << Y.Name << ": " << describe(Y.Ty);
Ty = Y.RestTy;
while (Ty->getKind() == TypeKind::Field) {
auto Y = Ty->asField();
out << "; " + Y.Name + ": " + describe(Y.Ty);
Ty = Y.RestTy;
}
if (Ty->getKind() != TypeKind::Nil) {
out << "; " + describe(Ty);
}
out << " }";
return out.str();
}
}
ZEN_UNREACHABLE
}
void writeForegroundANSI(Color C, std::ostream& Out) {
switch (C) {
case Color::None:
@ -533,153 +460,6 @@ void ConsolePrinter::writeBinding(const ByteString& Name) {
write("'");
}
void ConsolePrinter::writeType(const Type* Ty) {
TypePath Path;
writeType(Ty, Path);
}
void ConsolePrinter::writeType(const Type* Ty, const TypePath& Underline) {
setForegroundColor(Color::Green);
class TypePrinter : public ConstTypeVisitor {
TypePath Path;
ConsolePrinter& W;
const TypePath& Underline;
public:
TypePrinter(ConsolePrinter& W, const TypePath& Underline):
W(W), Underline(Underline) {}
bool shouldUnderline() const {
return !Underline.empty() && Path == Underline;
}
void enterType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(true);
}
}
void exitType(const Type* Ty) override {
if (shouldUnderline()) {
W.setUnderline(false); // FIXME Should set to old value
}
}
void visitAppType(const TApp& Ty) override {
Path.push_back(TypeIndex::forAppOpType());
visit(Ty.Op);
Path.pop_back();
W.write(" ");
Path.push_back(TypeIndex::forAppArgType());
visit(Ty.Arg);
Path.pop_back();
}
void visitVarType(const TVar& Ty) override {
if (Ty.isRigid()) {
W.write(*Ty.Name);
return;
}
W.write("a");
W.write(Ty.Id);
}
void visitConType(const TCon& Ty) override {
W.write(Ty.DisplayName);
}
void visitArrowType(const TArrow& Ty) override {
Path.push_back(TypeIndex::forArrowParamType());
visit(Ty.ParamType);
Path.pop_back();
W.write(" -> ");
Path.push_back(TypeIndex::forArrowReturnType());
visit(Ty.ReturnType);
Path.pop_back();
}
void visitTupleType(const TTuple& Ty) override {
W.write("(");
if (Ty.ElementTypes.size()) {
auto Iter = Ty.ElementTypes.begin();
Path.push_back(TypeIndex::forTupleElement(0));
visit(*Iter++);
Path.pop_back();
std::size_t I = 1;
while (Iter != Ty.ElementTypes.end()) {
W.write(", ");
Path.push_back(TypeIndex::forTupleElement(I++));
visit(*Iter++);
Path.pop_back();
}
}
W.write(")");
}
void visitNilType(const TNil& Ty) override {
W.write("{}");
}
void visitAbsentType(const TAbsent& Ty) override {
W.write("Abs");
}
void visitPresentType(const TPresent& Ty) override {
Path.push_back(TypeIndex::forPresentType());
visit(Ty.Ty);
Path.pop_back();
}
void visitFieldType(const TField& Ty) override {
W.write("{ ");
W.write(Ty.Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Ty.Ty);
Path.pop_back();
auto Ty2 = Ty.RestTy;
Path.push_back(TypeIndex::forFieldRest());
std::size_t I = 1;
while (Ty2->isField()) {
auto Y = Ty2->asField();
W.write("; ");
W.write(Y.Name);
W.write(": ");
Path.push_back(TypeIndex::forFieldType());
visit(Y.Ty);
Path.pop_back();
Ty2 = Y.RestTy;
Path.push_back(TypeIndex::forFieldRest());
++I;
}
if (Ty2->getKind() != TypeKind::Nil) {
W.write("; ");
visit(Ty2);
}
W.write(" }");
for (auto K = 0; K < I; K++) {
Path.pop_back();
}
}
};
TypePrinter P { *this, Underline };
P.visit(Ty);
resetStyles();
}
void ConsolePrinter::writeType(std::size_t I) {
setForegroundColor(Color::Green);
write(I);
resetStyles();
}
void ConsolePrinter::writeNode(const Node* N) {
auto Range = N->getRange();
writeExcerpt(N->getSourceFile()->getTextFile(), Range, Range, Color::Red);
@ -703,19 +483,42 @@ void ConsolePrinter::writePrefix(const Diagnostic& D) {
resetStyles();
}
void ConsolePrinter::writeTypeclassName(const ByteString& Name) {
setForegroundColor(Color::Magenta);
write(Name);
resetStyles();
}
void ConsolePrinter::writeTypeclassSignature(const TypeclassSignature& Sig) {
setForegroundColor(Color::Magenta);
write(Sig.Id);
for (auto TV: Sig.Params) {
write(" ");
write(describe(TV));
}
void ConsolePrinter::writeType(Type* Ty) {
std::function<void(Type*)> visit = [&](auto Ty) {
switch (Ty->getKind()) {
case TypeKind::Var:
{
auto T = static_cast<TVar*>(Ty);
// FIXME
write("α");
break;
}
case TypeKind::Con:
{
auto T = static_cast<TCon*>(Ty);
write(T->getName());
break;
}
case TypeKind::Fun:
{
auto T = static_cast<TFun*>(Ty);
visit(T->getLeft());
write(" -> ");
visit(T->getRight());
break;
}
case TypeKind::App:
{
auto T = static_cast<TApp*>(Ty);
visit(T->getLeft());
write(" ");
visit(T->getRight());
break;
}
}
};
setForegroundColor(Color::Green);
visit(Ty);
resetStyles();
}
@ -799,11 +602,14 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) {
return;
}
case DiagnosticKind::UnificationError:
case DiagnosticKind::TypeMismatchError:
{
auto& E = static_cast<const UnificationErrorDiagnostic&>(D);
auto Left = E.OrigLeft->resolve(E.LeftPath);
auto Right = E.OrigRight->resolve(E.RightPath);
auto& E = static_cast<const TypeMismatchError&>(D);
// auto Left = E.OrigLeft->resolve(E.LeftPath);
// auto Right = E.OrigRight->resolve(E.RightPath);
auto Left = E.Left;
auto Right = E.Right;
auto S = E.getNode();
writePrefix(E);
write("the types ");
writeType(Left);
@ -815,7 +621,7 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) {
write(" info: ");
resetStyles();
write("due to an equality constraint on ");
write(describe(E.Source->getKind()));
write(describe(S->getKind()));
write(":\n\n");
// write(" - left type ");
// writeType(E.OrigLeft, E.LeftPath);
@ -823,7 +629,7 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) {
// write(" - right type ");
// writeType(E.OrigRight, E.RightPath);
// write("\n\n");
writeNode(E.Source);
writeNode(S);
write("\n");
// if (E.Left != E.OrigLeft) {
// setForegroundColor(Color::Yellow);
@ -850,87 +656,6 @@ void ConsolePrinter::writeDiagnostic(const Diagnostic& D) {
return;
}
case DiagnosticKind::TypeclassMissing:
{
auto& E = static_cast<const TypeclassMissingDiagnostic&>(D);
writePrefix(E);
write("the type class ");
writeTypeclassSignature(E.Sig);
write(" is missing from the declaration's type signature\n\n");
writeNode(E.Decl);
write("\n\n");
return;
}
case DiagnosticKind::InstanceNotFound:
{
auto& E = static_cast<const InstanceNotFoundDiagnostic&>(D);
writePrefix(E);
write("a type class instance ");
writeTypeclassName(E.TypeclassName);
write(" ");
writeType(E.Ty);
write(" was not found.\n\n");
writeNode(E.Source);
write("\n");
return;
}
case DiagnosticKind::TupleIndexOutOfRange:
{
auto& E = static_cast<const TupleIndexOutOfRangeDiagnostic&>(D);
writePrefix(E);
write("the index ");
writeType(E.I);
write(" is out of range for tuple ");
writeType(E.Tuple);
write("\n\n");
writeNode(E.Source);
write("\n");
return;
}
case DiagnosticKind::InvalidTypeToTypeclass:
{
auto& E = static_cast<const InvalidTypeToTypeclassDiagnostic&>(D);
writePrefix(E);
write("the type ");
writeType(E.Actual);
write(" was applied to type class names ");
bool First = true;
for (auto Class: E.Classes) {
if (First) First = false;
else write(", ");
writeTypeclassName(Class);
}
write(" but this is invalid\n\n");
return;
}
case DiagnosticKind::FieldNotFound:
{
auto& E = static_cast<const FieldNotFoundDiagnostic&>(D);
writePrefix(E);
write("the field '");
write(E.Name);
write("' was required in one type but not found in another\n\n");
writeNode(E.Source);
write("\n");
return;
}
case DiagnosticKind::NotATuple:
{
auto& E = static_cast<const NotATupleDiagnostic&>(D);
writePrefix(E);
write("the type ");
writeType(E.Ty);
write(" is not a tuple.\n\n");
writeNode(E.Source);
write("\n");
return;
}
}
ZEN_UNREACHABLE

View file

@ -2,10 +2,9 @@
// FIXME writeExcerpt does not work well with the last line in a file
#include <sstream>
#include <algorithm>
#include <cmath>
#include "zen/config.hpp"
#include "bolt/CST.hpp"
#include "bolt/Type.hpp"
#include "bolt/DiagnosticEngine.hpp"

View file

@ -1,6 +1,4 @@
#include "zen/range.hpp"
#include "bolt/CST.hpp"
#include "bolt/Evaluator.hpp"
@ -62,8 +60,24 @@ Value Evaluator::apply(Value Op, std::vector<Value> Args) {
{
auto Fn = Op.getDeclaration();
Env NewEnv;
for (auto [Param, Arg]: zen::zip(Fn->getParams(), Args)) {
assignPattern(Param->Pattern, Arg, NewEnv);
auto Params= Fn->getParams();
auto ParamIter = Params.begin();
auto ParamsEnd = Params.end();
auto ArgIter = Args.begin();
auto ArgsEnd= Args.end();
for (;;) {
if (ParamIter == ParamsEnd && ArgIter == ArgsEnd) {
break;
}
if (ParamIter == ParamsEnd) {
// TODO Make this a soft failure
ZEN_PANIC("Too much arguments supplied to function call.");
}
if (ArgIter == ArgsEnd) {
// TODO Make this a soft failure
ZEN_PANIC("Too much few arguments supplied to function call.");
}
assignPattern((*ParamIter)->Pattern, *ArgIter, NewEnv);
}
switch (Fn->getBody()->getKind()) {
case NodeKind::LetExprBody:

View file

@ -2,12 +2,11 @@
#include <cmath>
#include <memory>
#include "llvm/IR/Value.h"
#include "LLVMCodeGen.hpp"
#include "bolt/CST.hpp"
#include "bolt/CSTVisitor.hpp"
#include "LLVMCodeGen.hpp"
namespace bolt {
LLVMCodeGen::LLVMCodeGen(llvm::LLVMContext* TheContext):

View file

@ -1,7 +1,8 @@
#pragma once
#include "llvm/IR/Value.h"
#include <memory>
#include "llvm/IR/IRBuilder.h"
namespace bolt {

View file

@ -4,6 +4,8 @@
#include <tuple>
#include <vector>
#include "zen/config.hpp"
#include "bolt/Common.hpp"
#include "bolt/CST.hpp"
#include "bolt/Scanner.hpp"

3
src/Program.cc Normal file
View file

@ -0,0 +1,3 @@
#include "bolt/Program.hpp"

View file

@ -1,8 +1,6 @@
#include <unordered_map>
#include "zen/config.hpp"
#include "bolt/Common.hpp"
#include "bolt/Text.hpp"
#include "bolt/Integer.hpp"

102
src/Type.cc Normal file
View file

@ -0,0 +1,102 @@
#include "zen/config.hpp"
#include "bolt/Type.hpp"
namespace bolt {
Type* Type::resolve(const TypePath& P) {
auto Ty = this;
for (auto& Index: P) {
switch (Index.Kind) {
case TypeIndexKind::AppOp:
Ty = static_cast<TApp*>(Ty)->getLeft();
break;
case TypeIndexKind::AppArg:
Ty = static_cast<TApp*>(Ty)->getRight();
break;
case TypeIndexKind::ArrowLeft:
Ty = static_cast<TFun*>(Ty)->getLeft();
break;
case TypeIndexKind::ArrowRight:
Ty = static_cast<TFun*>(Ty)->getRight();
break;
default:
ZEN_UNREACHABLE
}
}
return Ty;
}
bool Type::operator==(const Type& Other) const {
if (Other.getKind() != TK) {
return false;
}
switch (TK) {
case TypeKind::App:
{
auto A1 = static_cast<const TApp&>(*this);
auto A2 = static_cast<const TApp&>(Other);
return *A1.getLeft() == *A2.getLeft() && *A1.getRight() == *A2.getRight();
}
case TypeKind::Var:
return this == &Other;
case TypeKind::Fun:
{
auto F1 = static_cast<const TFun&>(*this);
auto F2 = static_cast<const TFun&>(Other);
return *F1.getLeft() == *F2.getLeft() && *F1.getRight() == *F2.getRight();
}
case TypeKind::Con:
{
auto C1 = static_cast<const TCon&>(*this);
auto C2 = static_cast<const TCon&>(Other);
return C1.getName() == C2.getName();
}
}
}
std::string Type::toString() const {
switch (TK) {
case TypeKind::App:
{
auto A = static_cast<const TApp*>(this);
return A->getLeft()->toString() + " " + A->getRight()->toString();
}
case TypeKind::Con:
{
auto C = static_cast<const TCon*>(this);
return std::string(C->getName());
}
case TypeKind::Fun:
{
auto F = static_cast<const TFun*>(this);
return F->getLeft()->toString() + " -> " + F->getRight()->toString();
}
case TypeKind::Var:
return "α";
}
}
TVar* Type::asVar() {
return static_cast<TVar*>(this);
}
void TypeVisitor::visit(Type* Ty) {
switch (Ty->getKind()) {
case TypeKind::App:
visitApp(static_cast<TApp*>(Ty));
break;
case TypeKind::Con:
visitCon(static_cast<TCon*>(Ty));
break;
case TypeKind::Fun:
visitFun(static_cast<TFun*>(Ty));
break;
case TypeKind::Var:
visitVar(static_cast<TVar*>(Ty));
break;
}
}
}

View file

@ -1,336 +0,0 @@
#include "bolt/Type.hpp"
#include <cwchar>
#include <sys/wait.h>
#include <vector>
#include "zen/range.hpp"
namespace bolt {
bool TypeclassSignature::operator<(const TypeclassSignature& Other) const {
if (Id < Other.Id) {
return true;
}
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Params[0]->asCon().Id < Other.Params[0]->asCon().Id;
}
bool TypeclassSignature::operator==(const TypeclassSignature& Other) const {
ZEN_ASSERT(Params.size() == 1);
ZEN_ASSERT(Other.Params.size() == 1);
return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id;
}
bool TypeIndex::operator==(const TypeIndex& Other) const noexcept {
if (Kind != Other.Kind) {
return false;
}
switch (Kind) {
case TypeIndexKind::ArrowParamType:
case TypeIndexKind::TupleElement:
return I == Other.I;
default:
return true;
}
}
bool TCon::operator==(const TCon& Other) const {
return Id == Other.Id;
}
bool TApp::operator==(const TApp& Other) const {
return *Op == *Other.Op && *Arg == *Other.Arg;
}
bool TVar::operator==(const TVar& Other) const {
return Id == Other.Id;
}
bool TArrow::operator==(const TArrow& Other) const {
return *ParamType == *Other.ParamType
&& *ReturnType == *Other.ReturnType;
}
bool TTuple::operator==(const TTuple& Other) const {
for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) {
if (*T1 != *T2) {
return false;
}
}
return true;
}
bool TNil::operator==(const TNil& Other) const {
return true;
}
bool TField::operator==(const TField& Other) const {
return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy;
}
bool TAbsent::operator==(const TAbsent& Other) const {
return true;
}
bool TPresent::operator==(const TPresent& Other) const {
return *Ty == *Other.Ty;
}
bool Type::operator==(const Type& Other) const {
if (Kind != Other.Kind) {
return false;
}
switch (Kind) {
case TypeKind::Var:
return Var == Other.Var;
case TypeKind::Con:
return Con == Other.Con;
case TypeKind::Present:
return Present == Other.Present;
case TypeKind::Absent:
return Absent == Other.Absent;
case TypeKind::Arrow:
return Arrow == Other.Arrow;
case TypeKind::Field:
return Field == Other.Field;
case TypeKind::Nil:
return Nil == Other.Nil;
case TypeKind::Tuple:
return Tuple == Other.Tuple;
case TypeKind::App:
return App == Other.App;
}
ZEN_UNREACHABLE
}
void Type::visitEachChild(std::function<void(Type*)> Proc) {
switch (Kind) {
case TypeKind::Var:
case TypeKind::Absent:
case TypeKind::Nil:
case TypeKind::Con:
break;
case TypeKind::Arrow:
{
Proc(Arrow.ParamType);
Proc(Arrow.ReturnType);
break;
}
case TypeKind::Tuple:
{
for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) {
Proc(Tuple.ElementTypes[I]);
}
break;
}
case TypeKind::App:
{
Proc(App.Op);
Proc(App.Arg);
break;
}
case TypeKind::Field:
{
Proc(Field.Ty);
Proc(Field.RestTy);
break;
}
case TypeKind::Present:
{
Proc(Present.Ty);
break;
}
}
}
Type* Type::rewrite(std::function<Type*(Type*)> Fn, bool Recursive) {
auto Ty2 = Fn(this);
if (this != Ty2) {
if (Recursive) {
return Ty2->rewrite(Fn, Recursive);
}
return Ty2;
}
switch (Kind) {
case TypeKind::Var:
return Ty2;
case TypeKind::Arrow:
{
auto Arrow = Ty2->asArrow();
bool Changed = false;
Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive);
if (NewParamType != Arrow.ParamType) {
Changed = true;
}
auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive);
if (NewRetTy != Arrow.ReturnType) {
Changed = true;
}
return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2;
}
case TypeKind::Con:
return Ty2;
case TypeKind::App:
{
auto App = Ty2->asApp();
auto NewOp = App.Op->rewrite(Fn, Recursive);
auto NewArg = App.Arg->rewrite(Fn, Recursive);
if (NewOp == App.Op && NewArg == App.Arg) {
return Ty2;
}
return new Type(TApp(NewOp, NewArg));
}
case TypeKind::Tuple:
{
auto Tuple = Ty2->asTuple();
bool Changed = false;
std::vector<Type*> NewElementTypes;
for (auto Ty: Tuple.ElementTypes) {
auto NewElementType = Ty->rewrite(Fn, Recursive);
if (NewElementType != Ty) {
Changed = true;
}
NewElementTypes.push_back(NewElementType);
}
return Changed ? new Type(TTuple(NewElementTypes)) : Ty2;
}
case TypeKind::Nil:
return Ty2;
case TypeKind::Absent:
return Ty2;
case TypeKind::Field:
{
auto Field = Ty2->asField();
bool Changed = false;
auto NewTy = Field.Ty->rewrite(Fn, Recursive);
if (NewTy != Field.Ty) {
Changed = true;
}
auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive);
if (NewRestTy != Field.RestTy) {
Changed = true;
}
return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2;
}
case TypeKind::Present:
{
auto Present = Ty2->asPresent();
auto NewTy = Present.Ty->rewrite(Fn, Recursive);
if (NewTy == Present.Ty) {
return Ty2;
}
return new Type(TPresent(NewTy));
}
}
ZEN_UNREACHABLE
}
Type* Type::substitute(const TVSub &Sub) {
return rewrite([&](auto Ty) {
if (Ty->isVar()) {
auto Match = Sub.find(Ty);
return Match != Sub.end() ? Match->second->substitute(Sub) : Ty;
}
return Ty;
}, false);
}
Type* Type::resolve(const TypeIndex& Index) const noexcept {
switch (Index.Kind) {
case TypeIndexKind::PresentType:
return this->asPresent().Ty;
case TypeIndexKind::AppOpType:
return this->asApp().Op;
case TypeIndexKind::AppArgType:
return this->asApp().Arg;
case TypeIndexKind::TupleElement:
return this->asTuple().ElementTypes[Index.I];
case TypeIndexKind::ArrowParamType:
return this->asArrow().ParamType;
case TypeIndexKind::ArrowReturnType:
return this->asArrow().ReturnType;
case TypeIndexKind::FieldType:
return this->asField().Ty;
case TypeIndexKind::FieldRestType:
return this->asField().RestTy;
case TypeIndexKind::End:
ZEN_UNREACHABLE
}
ZEN_UNREACHABLE
}
TVSet Type::getTypeVars() {
TVSet Out;
std::function<void(Type*)> visit = [&](Type* Ty) {
if (Ty->isVar()) {
Out.emplace(Ty);
return;
}
Ty->visitEachChild(visit);
};
visit(this);
return Out;
}
TypeIterator Type::begin() {
return TypeIterator { this, getStartIndex() };
}
TypeIterator Type::end() {
return TypeIterator { this, getEndIndex() };
}
TypeIndex Type::getStartIndex() const {
switch (Kind) {
case TypeKind::Arrow:
return TypeIndex::forArrowParamType();
case TypeKind::Tuple:
{
if (asTuple().ElementTypes.empty()) {
return TypeIndex(TypeIndexKind::End);
}
return TypeIndex::forTupleElement(0);
}
case TypeKind::Field:
return TypeIndex::forFieldType();
default:
return TypeIndex(TypeIndexKind::End);
}
}
TypeIndex Type::getEndIndex() const {
return TypeIndex(TypeIndexKind::End);
}
bool Type::hasTypeVar(Type* TV) const {
switch (Kind) {
case TypeKind::Var:
return Var.Id == TV->asVar().Id;
case TypeKind::Con:
case TypeKind::Absent:
case TypeKind::Nil:
return false;
case TypeKind::App:
return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV);
case TypeKind::Tuple:
for (auto Ty: Tuple.ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
case TypeKind::Field:
return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV);
case TypeKind::Arrow:
return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV);
case TypeKind::Present:
return Present.Ty->hasTypeVar(TV);
}
ZEN_UNREACHABLE
}
}

View file

@ -6,7 +6,6 @@
#include <algorithm>
#include <map>
#include "zen/config.hpp"
#include "zen/po.hpp"
#include "bolt/CST.hpp"
@ -113,11 +112,11 @@ int main(int Argc, const char* Argv[]) {
void visitExpression(Expression* N) {
for (auto A: N->Annotations) {
if (A->getKind() == NodeKind::TypeAssertAnnotation) {
auto Left = C.getType(N);
auto Left = C.getTypeOfNode(N);
auto Right = static_cast<TypeAssertAnnotation*>(A)->getTypeExpression()->getType();
std::cerr << "verify " << describe(Left) << " == " << describe(Right) << std::endl;
std::cerr << "verify " << Left->toString() << " == " << Right->toString() << std::endl;
if (*Left != *Right) {
DE.add<UnificationErrorDiagnostic>(Left, Right, TypePath(), TypePath(), A);
DE.add<TypeMismatchError>(Left, Right, A);
}
}
}

7
x.py
View file

@ -167,7 +167,8 @@ def build_bolt(c_path: str | None = None, cxx_path: str | None = None) -> None:
'CMAKE_BUILD_TYPE': 'Debug',
'BOLT_ENABLE_TESTS': True,
'ZEN_ENABLE_TESTS': False,
'LLVM_CONFIG': str(llvm_config_path)
#'LLVM_CONFIG': str(llvm_config_path),
'LLVM_TARGETS_TO_BUILD': 'X86',
}
if c_path is not None:
defines['CMAKE_C_COMPILER'] = c_path
@ -196,13 +197,13 @@ c_path = None
cxx_path = None
if os.name == 'posix':
clang_c_path = shutil.which('clangj')
clang_c_path = shutil.which('clang')
clang_cxx_path = shutil.which('clang++')
if clang_c_path is not None and clang_cxx_path is not None and (force == NONE or force == CLANG):
c_path = clang_c_path
cxx_path = clang_cxx_path
else:
for version in [ '18' ]:
for version in [ '18', '19' ]:
clang_c_path = shutil.which(f'clang-{version}')
clang_cxx_path = shutil.which(f'clang++-{version}')
if clang_c_path is not None and clang_cxx_path is not None and (force == NONE or force == CLANG):