Split LetDeclaration into VariableDeclaration and FunctionDeclaration

We only generate VariableDeclaration when we're absolutely sure it is a
variable.
This commit is contained in:
Sam Vervaeck 2024-03-09 12:51:35 +01:00
parent 2bfd88b05f
commit 719dbfcad4
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
12 changed files with 1156 additions and 564 deletions

View file

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.10) cmake_minimum_required(VERSION 3.10)
project(Bolt CXX) project(Bolt C CXX)
set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD 20)
@ -17,6 +17,8 @@ if (CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo" OR CMAKE_BUILD_TYPE STREQUAL "Deb
set(BOLT_DEBUG ON) set(BOLT_DEBUG ON)
endif() endif()
find_package(LLVM 18.1.0 REQUIRED)
add_library( add_library(
BoltCore BoltCore
#src/Text.cc #src/Text.cc
@ -28,6 +30,7 @@ add_library(
src/Types.cc src/Types.cc
src/Checker.cc src/Checker.cc
src/Evaluator.cc src/Evaluator.cc
src/Scope.cc
) )
target_link_directories( target_link_directories(
BoltCore BoltCore
@ -61,6 +64,22 @@ target_link_libraries(
icuuc icuuc
) )
add_library(
BoltLLVM
src/LLVMCodeGen.cc
)
llvm_map_components_to_libnames(llvm_libs support core irreader)
target_include_directories(BoltLLVM PRIVATE ${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
target_compile_definitions(BoltLLVM PRIVATE ${LLVM_DEFINITIONS_LIST})
target_link_libraries(
BoltLLVM
PUBLIC
BoltCore
${llvm_libs}
)
add_executable( add_executable(
bolt bolt
src/main.cc src/main.cc
@ -69,6 +88,7 @@ target_link_libraries(
bolt bolt
PUBLIC PUBLIC
BoltCore BoltCore
BoltLLVM
) )
if (BOLT_ENABLE_TESTS) if (BOLT_ENABLE_TESTS)

View file

@ -1,11 +1,13 @@
#ifndef BOLT_CST_HPP #ifndef BOLT_CST_HPP
#define BOLT_CST_HPP #define BOLT_CST_HPP
#include <cstdlib>
#include <limits> #include <limits>
#include <unordered_map> #include <unordered_map>
#include <variant> #include <variant>
#include <vector> #include <vector>
#include "bolt/Common.hpp"
#include "zen/config.hpp" #include "zen/config.hpp"
#include "bolt/Integer.hpp" #include "bolt/Integer.hpp"
@ -172,7 +174,11 @@ enum class NodeKind {
Parameter, Parameter,
LetBlockBody, LetBlockBody,
LetExprBody, LetExprBody,
LetDeclaration, PrefixFunctionDeclaration,
InfixFunctionDeclaration,
SuffixFunctionDeclaration,
NamedFunctionDeclaration,
VariableDeclaration,
RecordDeclarationField, RecordDeclarationField,
RecordDeclaration, RecordDeclaration,
VariantDeclaration, VariantDeclaration,
@ -364,31 +370,6 @@ public:
}; };
/// Any node that can be used as an operator
///
/// This includes the following nodes:
/// - VBar
/// - CustomOperator
using Operator = Token;
/// Any node that can be used as a kind of identifier.
///
/// This includes the following nodes:
/// - Identifier
/// - IdentifierAlt
/// - WrappedOperator
using Symbol = Node;
inline bool isSymbol(const Node* N) {
return N->getKind() == NodeKind::Identifier
|| N->getKind() == NodeKind::IdentifierAlt
|| N->getKind() == NodeKind::WrappedOperator;
}
/// Get the text that is actually represented by a symbol, without all the
/// syntactic sugar.
ByteString getCanonicalText(const Symbol* N);
class Equals : public Token { class Equals : public Token {
public: public:
@ -903,6 +884,8 @@ public:
std::string getText() const override; std::string getText() const override;
std::string getCanonicalText() const;
static bool classof(const Node* N) { static bool classof(const Node* N) {
return N->getKind() == NodeKind::CustomOperator; return N->getKind() == NodeKind::CustomOperator;
} }
@ -935,6 +918,8 @@ public:
std::string getText() const override; std::string getText() const override;
ByteString getCanonicalText() const;
bool isTypeVar() const; bool isTypeVar() const;
static bool classof(const Node* N) { static bool classof(const Node* N) {
@ -953,6 +938,8 @@ public:
std::string getText() const override; std::string getText() const override;
ByteString getCanonicalText() const;
static bool classof(const Node* N) { static bool classof(const Node* N) {
return N->getKind() == NodeKind::IdentifierAlt; return N->getKind() == NodeKind::IdentifierAlt;
} }
@ -1021,6 +1008,182 @@ public:
}; };
/// Base node for things that can be used as an operator
///
/// This includes the following nodes:
/// - VBar
/// - CustomOperator
class Operator {
Node* N;
Operator(Node* N):
N(N) {}
public:
Operator() {}
Operator(VBar* N):
N(N) {}
Operator(CustomOperator* N):
N(N) {}
static Operator from_raw_node(Node* N) {
ZEN_ASSERT(isa<Operator>(N));
return N;
}
inline NodeKind getKind() const {
return N->getKind();
}
inline bool isVBar() const {
return N->getKind() == NodeKind::VBar;
}
inline bool isCustomOperator() const {
return N->getKind() == NodeKind::CustomOperator;
}
VBar* asVBar() const {
return static_cast<VBar*>(N);
}
CustomOperator* asCustomOperator() const {
return static_cast<CustomOperator*>(N);
}
operator Node*() const {
return N;
}
/// Get the name that is actually represented by an operator, without all the
/// syntactic sugar.
virtual ByteString getCanonicalText() const;
Token* getFirstToken() const;
Token* getLastToken() const;
inline static bool classof(const Node* N) {
return N->getKind() == NodeKind::VBar
|| N->getKind() == NodeKind::CustomOperator;
}
};
class WrappedOperator : public Node {
public:
class LParen* LParen;
Operator Op;
class RParen* RParen;
WrappedOperator(
class LParen* LParen,
Operator Operator,
class RParen* RParen
): Node(NodeKind::WrappedOperator),
LParen(LParen),
Op(Operator),
RParen(RParen) {}
inline Operator getOperator() const {
return Op;
}
ByteString getCanonicalText() const {
return Op.getCanonicalText();
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::WrappedOperator;
}
};
/// Base node for things that can be used as a symbol
///
/// This includes the following nodes:
/// - WrappedOperator
/// - Identifier
/// - IdentifierAlt
class Symbol {
Node* N;
Symbol(Node* N):
N(N) {}
public:
Symbol() {}
Symbol(WrappedOperator* N):
N(N) {}
Symbol(Identifier* N):
N(N) {}
Symbol(IdentifierAlt* N):
N(N) {}
static Symbol from_raw_node(Node* N) {
ZEN_ASSERT(isa<Symbol>(N));
return N;
}
NodeKind getKind() const {
return N->getKind();
}
bool isWrappedOperator() const {
return N->getKind() == NodeKind::WrappedOperator;
}
bool isIdentifier() const {
return N->getKind() == NodeKind::Identifier;
}
bool isIdentifierAlt() const {
return N->getKind() == NodeKind::IdentifierAlt;
}
IdentifierAlt* asIdentifierAlt() const {
return cast<IdentifierAlt>(N);
}
Identifier* asIdentifier() const {
return cast<Identifier>(N);
}
WrappedOperator* asWrappedOperator() const {
return cast<WrappedOperator>(N);
}
operator Node*() const {
return N;
}
/// Get the name that is actually represented by a symbol, without all the
/// syntactic sugar.
ByteString getCanonicalText() const;
Token* getFirstToken() const;
Token* getLastToken() const;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::Identifier
|| N->getKind() == NodeKind::IdentifierAlt
|| N->getKind() == NodeKind::WrappedOperator;
}
};
class Annotation : public Node { class Annotation : public Node {
public: public:
@ -1353,31 +1516,6 @@ public:
}; };
class WrappedOperator : public Symbol {
public:
class LParen* LParen;
Token* Op;
class RParen* RParen;
WrappedOperator(
class LParen* LParen,
Token* Operator,
class RParen* RParen
): Symbol(NodeKind::WrappedOperator),
LParen(LParen),
Op(Operator),
RParen(RParen) {}
inline Token* getOperator() const {
return Op;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class Pattern : public Node { class Pattern : public Node {
protected: protected:
@ -1390,10 +1528,10 @@ protected:
class BindPattern : public Pattern { class BindPattern : public Pattern {
public: public:
Symbol* Name; Identifier* Name;
BindPattern( BindPattern(
Symbol* Name Identifier* Name
): Pattern(NodeKind::BindPattern), ): Pattern(NodeKind::BindPattern),
Name(Name) {} Name(Name) {}
@ -1608,11 +1746,11 @@ class ReferenceExpression : public Expression {
public: public:
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath; std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath;
Symbol* Name; Symbol Name;
inline ReferenceExpression( inline ReferenceExpression(
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath, std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath,
Symbol* Name Symbol Name
): Expression(NodeKind::ReferenceExpression), ): Expression(NodeKind::ReferenceExpression),
ModulePath(ModulePath), ModulePath(ModulePath),
Name(Name) {} Name(Name) {}
@ -1620,13 +1758,13 @@ public:
inline ReferenceExpression( inline ReferenceExpression(
std::vector<Annotation*> Annotations, std::vector<Annotation*> Annotations,
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath, std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath,
Symbol* Name Symbol Name
): Expression(NodeKind::ReferenceExpression, Annotations), ): Expression(NodeKind::ReferenceExpression, Annotations),
ModulePath(ModulePath), ModulePath(ModulePath),
Name(Name) {} Name(Name) {}
inline ByteString getNameAsString() const noexcept { inline ByteString getNameAsString() const noexcept {
return getCanonicalText(Name); return Name.getCanonicalText();
} }
Token* getFirstToken() const override; Token* getFirstToken() const override;
@ -1864,12 +2002,12 @@ class InfixExpression : public Expression {
public: public:
Expression* Left; Expression* Left;
Token* Operator; Operator Operator;
Expression* Right; Expression* Right;
inline InfixExpression( inline InfixExpression(
Expression* Left, Expression* Left,
Token* Operator, class Operator Operator,
Expression* Right Expression* Right
): Expression(NodeKind::InfixExpression), ): Expression(NodeKind::InfixExpression),
Left(Left), Left(Left),
@ -1879,7 +2017,7 @@ public:
inline InfixExpression( inline InfixExpression(
std::vector<Annotation*> Annotations, std::vector<Annotation*> Annotations,
Expression* Left, Expression* Left,
Token* Operator, class Operator Operator,
Expression* Right Expression* Right
): Expression(NodeKind::InfixExpression, Annotations), ): Expression(NodeKind::InfixExpression, Annotations),
Left(Left), Left(Left),
@ -2110,6 +2248,10 @@ public:
Token* getFirstToken() const override; Token* getFirstToken() const override;
Token* getLastToken() const override; Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::Parameter;
}
}; };
class LetBody : public Node { class LetBody : public Node {
@ -2155,7 +2297,7 @@ public:
}; };
class LetDeclaration : public TypedNode, public AnnotationContainer { class FunctionDeclaration : public TypedNode, public AnnotationContainer {
Scope* TheScope = nullptr; Scope* TheScope = nullptr;
@ -2165,63 +2307,34 @@ public:
bool Visited = false; bool Visited = false;
InferContext* Ctx; InferContext* Ctx;
class PubKeyword* PubKeyword; FunctionDeclaration(NodeKind Kind, std::vector<Annotation*> Annotations = {}):
class ForeignKeyword* ForeignKeyword; TypedNode(Kind), AnnotationContainer(Annotations) {}
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
LetDeclaration( virtual bool isPublic() const = 0;
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::LetDeclaration),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
LetDeclaration( virtual bool isForeign() const = 0;
std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword, virtual ByteString getNameAsString() const = 0;
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword, virtual std::vector<Parameter*> getParams() const = 0;
class MutKeyword* MutKeyword,
class Pattern* Pattern, virtual TypeAssert* getTypeAssert() const = 0;
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert, bool hasTypeAssert() const {
LetBody* Body return getTypeAssert();
): TypedNode(NodeKind::LetDeclaration), }
AnnotationContainer(Annotations),
PubKeyword(PubKeyword), virtual LetBody* getBody() const = 0;
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword), bool hasBody() const {
MutKeyword(MutKeyword), return getBody();
Pattern(Pattern), }
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
inline Scope* getScope() override { inline Scope* getScope() override {
if (isFunction()) { if (TheScope == nullptr) {
if (TheScope == nullptr) { TheScope = new Scope(this);
TheScope = new Scope(this);
}
return TheScope;
} }
return Parent->getScope(); return TheScope;
} }
bool isInstance() const noexcept { bool isInstance() const noexcept {
@ -2232,33 +2345,310 @@ public:
return Parent->getKind() == NodeKind::ClassDeclaration; return Parent->getKind() == NodeKind::ClassDeclaration;
} }
bool isSignature() const noexcept { static bool classof(const Node* N) {
return ForeignKeyword; return N->getKind() == NodeKind::PrefixFunctionDeclaration
|| N->getKind() == NodeKind::InfixFunctionDeclaration
|| N->getKind() == NodeKind::SuffixFunctionDeclaration
|| N->getKind() == NodeKind::NamedFunctionDeclaration;
} }
bool isVariable() const noexcept { };
// Variables in classes and instances are never possible, so we reflect this by excluding them here.
return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body); class PrefixFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class Operator Name;
Parameter* Param;
class TypeAssert* TypeAssert;
LetBody* Body;
PrefixFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
Operator Name,
Parameter* Param,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::PrefixFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Name(Name),
Param(Param),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
} }
bool isFunction() const noexcept { bool isForeign() const override {
return !isSignature() && !isVariable(); return ForeignKeyword != nullptr;
} }
Symbol* getName() const noexcept { ByteString getNameAsString() const override {
ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern); return Name.getCanonicalText();
return static_cast<BindPattern*>(Pattern)->Name;
} }
ByteString getNameAsString() const noexcept { std::vector<Parameter*> getParams() const override {
return getCanonicalText(getName()); return { Param };
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
} }
Token* getFirstToken() const override; Token* getFirstToken() const override;
Token* getLastToken() const override; Token* getLastToken() const override;
static bool classof(const Node* N) { static bool classof(const Node* N) {
return N->getKind() == NodeKind::LetDeclaration; return N->getKind() == NodeKind::PrefixFunctionDeclaration;
}
};
class SuffixFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
Parameter* Param;
class Operator Name;
class TypeAssert* TypeAssert;
LetBody* Body;
SuffixFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
Parameter* Param,
Operator Name,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::SuffixFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Name(Name),
Param(Param),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
std::vector<Parameter*> getParams() const override {
return { Param };
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::SuffixFunctionDeclaration;
}
};
class InfixFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
Parameter* Left;
class Operator Name;
Parameter* Right;
class TypeAssert* TypeAssert;
LetBody* Body;
InfixFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
Parameter* Left,
class Operator Name,
Parameter* Right,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::InfixFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Left(Left),
Name(Name),
Right(Right),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
std::vector<Parameter*> getParams() const override {
return { Left, Right };
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::InfixFunctionDeclaration;
}
};
class NamedFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class Symbol Name;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
NamedFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class Symbol Name,
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::NamedFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Name(Name),
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
std::vector<Parameter*> getParams() const override {
return Params;
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::NamedFunctionDeclaration;
}
};
class VariableDeclaration : public TypedNode, public AnnotationContainer {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
VariableDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::VariableDeclaration),
AnnotationContainer(Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
TypeAssert(TypeAssert),
Body(Body) {}
Symbol getName() const noexcept {
ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern);
return static_cast<BindPattern*>(Pattern)->Name;
}
ByteString getNameAsString() const noexcept {
return getName().getCanonicalText();
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::VariableDeclaration;
} }
}; };
@ -2555,7 +2945,11 @@ template<> inline NodeKind getNodeType<TypeAssert>() { return NodeKind::TypeAsse
template<> inline NodeKind getNodeType<Parameter>() { return NodeKind::Parameter; } template<> inline NodeKind getNodeType<Parameter>() { return NodeKind::Parameter; }
template<> inline NodeKind getNodeType<LetBlockBody>() { return NodeKind::LetBlockBody; } template<> inline NodeKind getNodeType<LetBlockBody>() { return NodeKind::LetBlockBody; }
template<> inline NodeKind getNodeType<LetExprBody>() { return NodeKind::LetExprBody; } template<> inline NodeKind getNodeType<LetExprBody>() { return NodeKind::LetExprBody; }
template<> inline NodeKind getNodeType<LetDeclaration>() { return NodeKind::LetDeclaration; } template<> inline NodeKind getNodeType<PrefixFunctionDeclaration>() { return NodeKind::PrefixFunctionDeclaration; }
template<> inline NodeKind getNodeType<InfixFunctionDeclaration>() { return NodeKind::InfixFunctionDeclaration; }
template<> inline NodeKind getNodeType<SuffixFunctionDeclaration>() { return NodeKind::SuffixFunctionDeclaration; }
template<> inline NodeKind getNodeType<NamedFunctionDeclaration()>() { return NodeKind::NamedFunctionDeclaration; }
template<> inline NodeKind getNodeType<VariableDeclaration>() { return NodeKind::VariableDeclaration; }
template<> inline NodeKind getNodeType<RecordDeclarationField>() { return NodeKind::RecordDeclarationField; } template<> inline NodeKind getNodeType<RecordDeclarationField>() { return NodeKind::RecordDeclarationField; }
template<> inline NodeKind getNodeType<RecordDeclaration>() { return NodeKind::RecordDeclaration; } template<> inline NodeKind getNodeType<RecordDeclaration>() { return NodeKind::RecordDeclaration; }
template<> inline NodeKind getNodeType<ClassDeclaration>() { return NodeKind::ClassDeclaration; } template<> inline NodeKind getNodeType<ClassDeclaration>() { return NodeKind::ClassDeclaration; }

View file

@ -15,8 +15,8 @@ public:
void visit(Node* N) { void visit(Node* N) {
#define BOLT_GEN_CASE(name) \ #define BOLT_GEN_CASE(name) \
case NodeKind::name: \ case NodeKind::name: \
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N)); return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
switch (N->getKind()) { switch (N->getKind()) {
BOLT_GEN_CASE(VBar) BOLT_GEN_CASE(VBar)
@ -104,7 +104,11 @@ public:
BOLT_GEN_CASE(Parameter) BOLT_GEN_CASE(Parameter)
BOLT_GEN_CASE(LetBlockBody) BOLT_GEN_CASE(LetBlockBody)
BOLT_GEN_CASE(LetExprBody) BOLT_GEN_CASE(LetExprBody)
BOLT_GEN_CASE(LetDeclaration) BOLT_GEN_CASE(PrefixFunctionDeclaration)
BOLT_GEN_CASE(InfixFunctionDeclaration)
BOLT_GEN_CASE(SuffixFunctionDeclaration)
BOLT_GEN_CASE(NamedFunctionDeclaration)
BOLT_GEN_CASE(VariableDeclaration)
BOLT_GEN_CASE(RecordDeclaration) BOLT_GEN_CASE(RecordDeclaration)
BOLT_GEN_CASE(RecordDeclarationField) BOLT_GEN_CASE(RecordDeclarationField)
BOLT_GEN_CASE(VariantDeclaration) BOLT_GEN_CASE(VariantDeclaration)
@ -116,6 +120,35 @@ public:
} }
} }
void dispatchSymbol(const Symbol& S) {
switch (S.getKind()) {
case NodeKind::Identifier:
visit(S.asIdentifier());
break;
case NodeKind::IdentifierAlt:
visit(S.asIdentifierAlt());
break;
case NodeKind::WrappedOperator:
visit(S.asWrappedOperator());
break;
default:
ZEN_UNREACHABLE
}
}
void dispatchOperator(const Operator& O) {
switch (O.getKind()) {
case NodeKind::VBar:
visit(O.asVBar());
break;
case NodeKind::CustomOperator:
visit(O.asCustomOperator());
break;
default:
ZEN_UNREACHABLE
}
}
protected: protected:
void visitNode(Node* N) { void visitNode(Node* N) {
@ -494,7 +527,27 @@ protected:
static_cast<D*>(this)->visitLetBody(N); static_cast<D*>(this)->visitLetBody(N);
} }
void visitLetDeclaration(LetDeclaration* N) { void visitFunctionDeclaration(FunctionDeclaration* N) {
static_cast<D*>(this)->visitNode(N);
}
void visitPrefixFunctionDeclaration(PrefixFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitInfixFunctionDeclaration(InfixFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitSuffixFunctionDeclaration(SuffixFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitNamedFunctionDeclaration(NamedFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitVariableDeclaration(VariableDeclaration* N) {
static_cast<D*>(this)->visitNode(N); static_cast<D*>(this)->visitNode(N);
} }
@ -629,7 +682,11 @@ public:
BOLT_GEN_CHILD_CASE(Parameter) BOLT_GEN_CHILD_CASE(Parameter)
BOLT_GEN_CHILD_CASE(LetBlockBody) BOLT_GEN_CHILD_CASE(LetBlockBody)
BOLT_GEN_CHILD_CASE(LetExprBody) BOLT_GEN_CHILD_CASE(LetExprBody)
BOLT_GEN_CHILD_CASE(LetDeclaration) BOLT_GEN_CHILD_CASE(PrefixFunctionDeclaration)
BOLT_GEN_CHILD_CASE(InfixFunctionDeclaration)
BOLT_GEN_CHILD_CASE(SuffixFunctionDeclaration)
BOLT_GEN_CHILD_CASE(NamedFunctionDeclaration)
BOLT_GEN_CHILD_CASE(VariableDeclaration)
BOLT_GEN_CHILD_CASE(RecordDeclaration) BOLT_GEN_CHILD_CASE(RecordDeclaration)
BOLT_GEN_CHILD_CASE(RecordDeclarationField) BOLT_GEN_CHILD_CASE(RecordDeclarationField)
BOLT_GEN_CHILD_CASE(VariantDeclaration) BOLT_GEN_CHILD_CASE(VariantDeclaration)
@ -642,6 +699,8 @@ public:
} }
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node) #define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
#define BOLT_VISIT_SYMBOL(node) static_cast<D*>(this)->dispatchSymbol(node)
#define BOLT_VISIT_OPERATOR(node) static_cast<D*>(this)->dispatchOperator(node)
void visitEachChild(VBar* N) { void visitEachChild(VBar* N) {
} }
@ -771,7 +830,7 @@ public:
void visitEachChild(WrappedOperator* N) { void visitEachChild(WrappedOperator* N) {
BOLT_VISIT(N->LParen); BOLT_VISIT(N->LParen);
BOLT_VISIT(N->Op); BOLT_VISIT_OPERATOR(N->Op);
BOLT_VISIT(N->RParen); BOLT_VISIT(N->RParen);
} }
@ -972,7 +1031,7 @@ public:
BOLT_VISIT(Name); BOLT_VISIT(Name);
BOLT_VISIT(Dot); BOLT_VISIT(Dot);
} }
BOLT_VISIT(N->Name); BOLT_VISIT_SYMBOL(N->Name);
} }
void visitEachChild(MatchCase* N) { void visitEachChild(MatchCase* N) {
@ -1049,7 +1108,7 @@ public:
BOLT_VISIT(A); BOLT_VISIT(A);
} }
BOLT_VISIT(N->Left); BOLT_VISIT(N->Left);
BOLT_VISIT(N->Operator); BOLT_VISIT_OPERATOR(N->Operator);
BOLT_VISIT(N->Right); BOLT_VISIT(N->Right);
} }
@ -1140,7 +1199,7 @@ public:
BOLT_VISIT(N->Expression); BOLT_VISIT(N->Expression);
} }
void visitEachChild(LetDeclaration* N) { void visitEachChild(PrefixFunctionDeclaration* N) {
for (auto A: N->Annotations) { for (auto A: N->Annotations) {
BOLT_VISIT(A); BOLT_VISIT(A);
} }
@ -1151,6 +1210,96 @@ public:
BOLT_VISIT(N->ForeignKeyword); BOLT_VISIT(N->ForeignKeyword);
} }
BOLT_VISIT(N->LetKeyword); BOLT_VISIT(N->LetKeyword);
BOLT_VISIT(N->Param);
BOLT_VISIT_OPERATOR(N->Name);
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(InfixFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT(N->Left);
BOLT_VISIT_OPERATOR(N->Name);
BOLT_VISIT(N->Right);
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(SuffixFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT_OPERATOR(N->Name);
BOLT_VISIT(N->Param);
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(NamedFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT_SYMBOL(N->Name);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(VariableDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
if (N->MutKeyword) {
BOLT_VISIT(N->MutKeyword);
}
BOLT_VISIT(N->Pattern); BOLT_VISIT(N->Pattern);
for (auto Param: N->Params) { for (auto Param: N->Params) {
BOLT_VISIT(Param); BOLT_VISIT(Param);

View file

@ -241,7 +241,7 @@ class Checker {
/// Type inference /// Type inference
void forwardDeclare(Node* Node); void forwardDeclare(Node* Node);
void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression); Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true); Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true);
@ -249,7 +249,7 @@ class Checker {
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void infer(Node* node); void infer(Node* node);
void inferFunctionDeclaration(LetDeclaration* N); void inferFunctionDeclaration(FunctionDeclaration* N);
void inferConstraintExpression(ConstraintExpression* C); void inferConstraintExpression(ConstraintExpression* C);
/// Factory methods /// Factory methods

View file

@ -1,6 +1,8 @@
#pragma once #pragma once
#include "zen/config.hpp"
namespace bolt { namespace bolt {
class LanguageConfig { class LanguageConfig {

View file

@ -29,7 +29,7 @@ class Value {
union { union {
ByteString S; ByteString S;
Integer I; Integer I;
LetDeclaration* D; FunctionDeclaration* D;
NativeFunction F; NativeFunction F;
Tuple T; Tuple T;
}; };
@ -45,7 +45,7 @@ public:
Value(Integer I): Value(Integer I):
Kind(ValueKind::Integer), I(I) {} Kind(ValueKind::Integer), I(I) {}
Value(LetDeclaration* D): Value(FunctionDeclaration* D):
Kind(ValueKind::SourceFunction), D(D) {} Kind(ValueKind::SourceFunction), D(D) {}
Value(NativeFunction F): Value(NativeFunction F):
@ -67,7 +67,7 @@ public:
new (&I) Tuple(V.T); new (&I) Tuple(V.T);
break; break;
case ValueKind::SourceFunction: case ValueKind::SourceFunction:
new (&D) LetDeclaration*(V.D); new (&D) FunctionDeclaration*(V.D);
break; break;
case ValueKind::NativeFunction: case ValueKind::NativeFunction:
new (&F) NativeFunction(V.F); new (&F) NativeFunction(V.F);
@ -90,7 +90,7 @@ public:
new (&I) Tuple(Other.T); new (&I) Tuple(Other.T);
break; break;
case ValueKind::SourceFunction: case ValueKind::SourceFunction:
new (&D) LetDeclaration*(Other.D); new (&D) FunctionDeclaration*(Other.D);
break; break;
case ValueKind::NativeFunction: case ValueKind::NativeFunction:
new (&F) NativeFunction(Other.F); new (&F) NativeFunction(Other.F);
@ -112,7 +112,7 @@ public:
return S; return S;
} }
inline LetDeclaration* getDeclaration() { inline FunctionDeclaration* getDeclaration() {
ZEN_ASSERT(Kind == ValueKind::SourceFunction); ZEN_ASSERT(Kind == ValueKind::SourceFunction);
return D; return D;
} }

View file

@ -130,7 +130,7 @@ public:
Node* parseLetBodyElement(); Node* parseLetBodyElement();
LetDeclaration* parseLetDeclaration(); Node* parseLetDeclaration();
Node* parseClassElement(); Node* parseClassElement();

View file

@ -59,216 +59,6 @@ ByteString TextFile::getText() const {
return Text; return Text;
} }
Scope::Scope(Node* Source):
Source(Source) {
scan(Source);
}
void Scope::addSymbol(ByteString Name, Node* Decl, SymbolKind Kind) {
Mapping.emplace(Name, std::make_tuple(Decl, Kind));
}
void Scope::scan(Node* X) {
switch (X->getKind()) {
case NodeKind::SourceFile:
{
auto File = static_cast<SourceFile*>(X);
for (auto Element: File->Elements) {
scanChild(Element);
}
break;
}
case NodeKind::MatchCase:
{
auto Case = static_cast<MatchCase*>(X);
visitPattern(Case->Pattern, Case);
break;
}
case NodeKind::LetDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
ZEN_ASSERT(Decl->isFunction());
for (auto Param: Decl->Params) {
visitPattern(Param->Pattern, Param);
}
if (Decl->Body) {
scanChild(Decl->Body);
}
break;
}
default:
ZEN_UNREACHABLE
}
}
void Scope::scanChild(Node* X) {
switch (X->getKind()) {
case NodeKind::LetExprBody:
case NodeKind::ExpressionStatement:
case NodeKind::IfStatement:
case NodeKind::ReturnStatement:
break;
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(X);
for (auto Element: Block->Elements) {
scanChild(Element);
}
break;
}
case NodeKind::InstanceDeclaration:
// We ignore let-declarations inside instance-declarations for now
break;
case NodeKind::ClassDeclaration:
{
auto Decl = static_cast<ClassDeclaration*>(X);
addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Class);
for (auto Element: Decl->Elements) {
scanChild(Element);
}
break;
}
case NodeKind::LetDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
// No matter if it is a function or a variable, by visiting the pattern
// we add all relevant bindings to the current scope.
visitPattern(Decl->Pattern, Decl);
break;
}
case NodeKind::RecordDeclaration:
{
auto Decl = static_cast<RecordDeclaration*>(X);
addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type);
break;
}
case NodeKind::VariantDeclaration:
{
auto Decl = static_cast<VariantDeclaration*>(X);
addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type);
for (auto Member: Decl->Members) {
switch (Member->getKind()) {
case NodeKind::TupleVariantDeclarationMember:
{
auto T = static_cast<TupleVariantDeclarationMember*>(Member);
addSymbol(getCanonicalText(T->Name), Decl, SymbolKind::Constructor);
break;
}
case NodeKind::RecordVariantDeclarationMember:
{
auto R = static_cast<RecordVariantDeclarationMember*>(Member);
addSymbol(getCanonicalText(R->Name), Decl, SymbolKind::Constructor);
break;
}
default:
ZEN_UNREACHABLE
}
}
break;
}
default:
ZEN_UNREACHABLE
}
}
void Scope::visitPattern(Pattern* X, Node* Decl) {
switch (X->getKind()) {
case NodeKind::BindPattern:
{
auto Y = static_cast<BindPattern*>(X);
addSymbol(getCanonicalText(Y->Name), Decl, SymbolKind::Var);
break;
}
case NodeKind::RecordPattern:
{
auto Y = static_cast<RecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedRecordPattern:
{
auto Y = static_cast<NamedRecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedTuplePattern:
{
auto Y = static_cast<NamedTuplePattern*>(X);
for (auto P: Y->Patterns) {
visitPattern(P, Decl);
}
break;
}
case NodeKind::NestedPattern:
{
auto Y = static_cast<NestedPattern*>(X);
visitPattern(Y->P, Decl);
break;
}
case NodeKind::TuplePattern:
{
auto Y = static_cast<TuplePattern*>(X);
for (auto [Element, Comma]: Y->Elements) {
visitPattern(Element, Decl);
}
break;
}
case NodeKind::ListPattern:
{
auto Y = static_cast<ListPattern*>(X);
for (auto [Element, Separator]: Y->Elements) {
visitPattern(Element, Decl);
}
break;
}
case NodeKind::LiteralPattern:
break;
default:
ZEN_UNREACHABLE
}
}
Node* Scope::lookupDirect(SymbolPath Path, SymbolKind Kind) {
ZEN_ASSERT(Path.Modules.empty());
auto Match = Mapping.find(Path.Name);
if (Match != Mapping.end() && std::get<1>(Match->second) == Kind) {
return std::get<0>(Match->second);
}
return nullptr;
}
Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) {
ZEN_ASSERT(Path.Modules.empty());
auto Curr = this;
do {
auto Found = Curr->lookupDirect(Path, Kind);
if (Found) {
return Found;
}
Curr = Curr->getParentScope();
} while (Curr != nullptr);
return nullptr;
}
Scope* Scope::getParentScope() {
if (Source->Parent == nullptr) {
return nullptr;
}
return Source->Parent->getScope();
}
const SourceFile* Node::getSourceFile() const { const SourceFile* Node::getSourceFile() const {
const Node* CurrNode = this; const Node* CurrNode = this;
for (;;) { for (;;) {
@ -503,29 +293,11 @@ Token* WrappedOperator::getLastToken() const {
} }
Token* BindPattern::getFirstToken() const { Token* BindPattern::getFirstToken() const {
switch (Name->getKind()) { return Name;
case NodeKind::Identifier:
return static_cast<Identifier*>(Name);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->LParen;
default:
ZEN_UNREACHABLE
}
} }
Token* BindPattern::getLastToken() const { Token* BindPattern::getLastToken() const {
switch (Name->getKind()) { return Name;
case NodeKind::Identifier:
return static_cast<Identifier*>(Name);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->RParen;
default:
ZEN_UNREACHABLE
}
} }
Token* LiteralPattern::getFirstToken() const { Token* LiteralPattern::getFirstToken() const {
@ -608,26 +380,26 @@ Token* ReferenceExpression::getFirstToken() const {
if (!ModulePath.empty()) { if (!ModulePath.empty()) {
return std::get<0>(ModulePath.front()); return std::get<0>(ModulePath.front());
} }
switch (Name->getKind()) { switch (Name.getKind()) {
case NodeKind::Identifier: case NodeKind::Identifier:
return static_cast<Identifier*>(Name); return Name.asIdentifier();
case NodeKind::IdentifierAlt: case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name); return Name.asIdentifierAlt();
case NodeKind::WrappedOperator: case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->LParen; return Name.asWrappedOperator()->getFirstToken();
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
} }
Token* ReferenceExpression::getLastToken() const { Token* ReferenceExpression::getLastToken() const {
switch (Name->getKind()) { switch (Name.getKind()) {
case NodeKind::Identifier: case NodeKind::Identifier:
return static_cast<Identifier*>(Name); return Name.asIdentifier();
case NodeKind::IdentifierAlt: case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name); return Name.asIdentifierAlt();
case NodeKind::WrappedOperator: case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->RParen; return Name.asWrappedOperator()->getLastToken();
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
@ -805,7 +577,7 @@ Token* LetExprBody::getLastToken() const {
return Expression->getLastToken(); return Expression->getLastToken();
} }
Token* LetDeclaration::getFirstToken() const { Token* PrefixFunctionDeclaration::getFirstToken() const {
if (PubKeyword) { if (PubKeyword) {
return PubKeyword; return PubKeyword;
} }
@ -815,17 +587,97 @@ Token* LetDeclaration::getFirstToken() const {
return LetKeyword; return LetKeyword;
} }
Token* LetDeclaration::getLastToken() const { Token* PrefixFunctionDeclaration::getLastToken() const {
if (Body) { if (Body) {
return Body->getLastToken(); return Body->getLastToken();
} }
if (TypeAssert) { if (TypeAssert) {
return TypeAssert->getLastToken(); return TypeAssert->getLastToken();
} }
if (Params.size()) { return Param->getLastToken();
}
Token* InfixFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* InfixFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
return Right->getLastToken();
}
Token* SuffixFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* SuffixFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
return Name.getLastToken();
}
Token* NamedFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* NamedFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
if (!Params.empty()) {
return Params.back()->getLastToken(); return Params.back()->getLastToken();
} }
return Name.getLastToken();
}
Token* VariableDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* VariableDeclaration::getLastToken() const {
return Pattern->getLastToken(); return Pattern->getLastToken();
if (TypeAssert) {
return TypeAssert->getLastToken();
}
if (Body) {
return Body->getLastToken();
}
} }
Token* RecordDeclarationField::getFirstToken() const { Token* RecordDeclarationField::getFirstToken() const {
@ -1093,23 +945,81 @@ std::string InstanceKeyword::getText() const {
return "instance"; return "instance";
} }
ByteString getCanonicalText(const Symbol* N) { ByteString Identifier::getCanonicalText() const {
return Text;
}
ByteString IdentifierAlt::getCanonicalText() const {
return Text;
}
ByteString CustomOperator::getCanonicalText() const {
return Text;
}
ByteString Symbol::getCanonicalText() const {
switch (N->getKind()) { switch (N->getKind()) {
case NodeKind::Identifier: case NodeKind::Identifier:
return static_cast<const Identifier*>(N)->Text; return static_cast<const Identifier*>(N)->getCanonicalText();
case NodeKind::IdentifierAlt: case NodeKind::IdentifierAlt:
return static_cast<const IdentifierAlt*>(N)->Text; return static_cast<const IdentifierAlt*>(N)->getCanonicalText();
case NodeKind::CustomOperator: case NodeKind::CustomOperator:
return static_cast<const CustomOperator*>(N)->Text; return static_cast<const CustomOperator*>(N)->getCanonicalText();
case NodeKind::VBar: case NodeKind::VBar:
return static_cast<const VBar*>(N)->getText(); return static_cast<const VBar*>(N)->getText();
case NodeKind::WrappedOperator: case NodeKind::WrappedOperator:
return static_cast<const WrappedOperator*>(N)->getOperator()->getText(); return static_cast<const WrappedOperator*>(N)->getCanonicalText();
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
} }
Token* Symbol::getFirstToken() const {
switch (N->getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(N);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(N);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(N)->getFirstToken();
default:
ZEN_UNREACHABLE
}
}
Token* Symbol::getLastToken() const {
switch (N->getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(N);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(N);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(N)->getLastToken();
default:
ZEN_UNREACHABLE
}
}
ByteString Operator::getCanonicalText() const {
switch (N->getKind()) {
case NodeKind::CustomOperator:
return static_cast<const CustomOperator*>(N)->getCanonicalText();
case NodeKind::VBar:
return static_cast<const VBar*>(N)->getText();
default:
ZEN_UNREACHABLE
}
}
Token* Operator::getFirstToken() const {
return static_cast<Token*>(N);
}
Token* Operator::getLastToken() const {
return static_cast<Token*>(N);
}
LiteralValue StringLiteral::getValue() { LiteralValue StringLiteral::getValue() {
return Text; return Text;
} }
@ -1121,9 +1031,9 @@ LiteralValue IntegerLiteral::getValue() {
SymbolPath ReferenceExpression::getSymbolPath() const { SymbolPath ReferenceExpression::getSymbolPath() const {
std::vector<ByteString> ModuleNames; std::vector<ByteString> ModuleNames;
for (auto [Name, Dot]: ModulePath) { for (auto [Name, Dot]: ModulePath) {
ModuleNames.push_back(getCanonicalText(Name)); ModuleNames.push_back(Name->getCanonicalText());
} }
return SymbolPath { ModuleNames, getCanonicalText(Name) }; return SymbolPath { ModuleNames, Name.getCanonicalText() };
} }
} }

View file

@ -251,9 +251,9 @@ void Checker::forwardDeclare(Node* X) {
inferTypeExpression(TE); inferTypeExpression(TE);
} }
auto Match = InstanceMap.find(getCanonicalText(Decl->Name)); auto Match = InstanceMap.find(Decl->Name->getCanonicalText());
if (Match == InstanceMap.end()) { if (Match == InstanceMap.end()) {
InstanceMap.emplace(getCanonicalText(Decl->Name), std::vector { Decl }); InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl });
} else { } else {
Match->second.push_back(Decl); Match->second.push_back(Decl);
} }
@ -265,13 +265,15 @@ void Checker::forwardDeclare(Node* X) {
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
break;
case NodeKind::VariableDeclaration:
{ {
// Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression() auto Decl = static_cast<VariableDeclaration*>(X);
auto Decl = static_cast<LetDeclaration*>(X);
if (!Decl->isVariable()) {
break;
}
Type* Ty; Type* Ty;
if (Decl->TypeAssert) { if (Decl->TypeAssert) {
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression);
@ -290,13 +292,13 @@ void Checker::forwardDeclare(Node* X) {
std::vector<Type*> Vars; std::vector<Type*> Vars;
for (auto TE: Decl->TVs) { for (auto TE: Decl->TVs) {
auto TV = createRigidVar(getCanonicalText(TE->Name)); auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV); Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV); Vars.push_back(TV);
} }
Type* Ty = createConType(getCanonicalText(Decl->Name)); Type* Ty = createConType(Decl->Name->getCanonicalText());
// Build the type that is actually returned by constructor functions // Build the type that is actually returned by constructor functions
auto RetTy = Ty; auto RetTy = Ty;
@ -305,7 +307,7 @@ void Checker::forwardDeclare(Node* X) {
} }
// Must be added early so we can create recursive types // Must be added early so we can create recursive types
Decl->Ctx->Parent->Env.add(getCanonicalText(Decl->Name), new Forall(Ty), SymKind::Type); Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
for (auto Member: Decl->Members) { for (auto Member: Decl->Members) {
switch (Member->getKind()) { switch (Member->getKind()) {
@ -318,7 +320,7 @@ void Checker::forwardDeclare(Node* X) {
ParamTypes.push_back(inferTypeExpression(Element, false)); ParamTypes.push_back(inferTypeExpression(Element, false));
} }
Decl->Ctx->Parent->Env.add( Decl->Ctx->Parent->Env.add(
getCanonicalText(TupleMember->Name), TupleMember->Name->getCanonicalText(),
new Forall( new Forall(
Decl->Ctx->TVs, Decl->Ctx->TVs,
Decl->Ctx->Constraints, Decl->Ctx->Constraints,
@ -351,13 +353,13 @@ void Checker::forwardDeclare(Node* X) {
std::vector<Type*> Vars; std::vector<Type*> Vars;
for (auto TE: Decl->Vars) { for (auto TE: Decl->Vars) {
auto TV = createRigidVar(getCanonicalText(TE->Name)); auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV); Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV); Vars.push_back(TV);
} }
auto Name = getCanonicalText(Decl->Name); auto Name = Decl->Name->getCanonicalText();
auto Ty = createConType(Name); auto Ty = createConType(Name);
// Must be added early so we can create recursive types // Must be added early so we can create recursive types
@ -373,7 +375,7 @@ void Checker::forwardDeclare(Node* X) {
for (auto Field: Decl->Fields) { for (auto Field: Decl->Fields) {
FieldsTy = new Type( FieldsTy = new Type(
TField( TField(
getCanonicalText(Field->Name), Field->Name->getCanonicalText(),
new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))),
FieldsTy FieldsTy
) )
@ -435,13 +437,11 @@ void Checker::initialize(Node* N) {
Contexts.pop(); Contexts.pop();
} }
void visitLetDeclaration(LetDeclaration* Let) { void visitFunctionDeclaration(FunctionDeclaration* Func) {
if (Let->isFunction()) { Func->Ctx = createDerivedContext();
Let->Ctx = createDerivedContext(); Contexts.push(Func->Ctx);
Contexts.push(Let->Ctx); visitEachChild(Func);
visitEachChild(Let); Contexts.pop();
Contexts.pop();
}
} }
// void visitVariableDeclaration(VariableDeclaration* Var) { // void visitVariableDeclaration(VariableDeclaration* Var) {
@ -456,22 +456,18 @@ void Checker::initialize(Node* N) {
} }
void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) {
if (!Let->isFunction()) {
return;
}
// std::cerr << "declare " << Let->getNameAsString() << std::endl; // std::cerr << "declare " << Let->getNameAsString() << std::endl;
setContext(Let->Ctx); setContext(Let->Ctx);
auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) {
auto Id = getCanonicalText(Class->Name); auto Id = Class->Name->getCanonicalText();
auto Ctx = &getContext(); auto Ctx = &getContext();
std::vector<Type*> Out; std::vector<Type*> Out;
for (auto TE: Class->TypeVars) { for (auto TE: Class->TypeVars) {
auto Name = getCanonicalText(TE->Name); auto Name = TE->Name->getCanonicalText();
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
TV->asVar().Context.emplace(Id); TV->asVar().Context.emplace(Id);
Ctx->Env.add(Name, new Forall(TV), SymKind::Type); Ctx->Env.add(Name, new Forall(TV), SymKind::Type);
@ -493,8 +489,8 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
// Otherwise, the type is not further specified and we create a new // Otherwise, the type is not further specified and we create a new
// unification variable. // unification variable.
Type* Ty; Type* Ty;
if (Let->TypeAssert) { if (Let->hasTypeAssert()) {
Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); Ty = inferTypeExpression(Let->getTypeAssert()->TypeExpression);
} else { } else {
Ty = createTypeVar(); Ty = createTypeVar();
} }
@ -507,9 +503,33 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
if (Let->isInstance()) { if (Let->isInstance()) {
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent); auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
auto Class = cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, getCanonicalText(Instance->Name) }, SymbolKind::Class)); auto Class = cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
// TODO check if `Class` is nullptr
auto SigLet = cast<LetDeclaration>(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); if (Class == nullptr) {
// TODO print diagnostic
// DE.add<TypeclassNotFoundDiagnostic>(Instance->Name->getCanonicalText());
goto after_isinstance;
}
auto Decl = Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var);
if (Decl == nullptr) {
// TODO print diagnostic
// DE.add<MethodNotFoundInTypeclass>(Let->getNameAsStrings), Let->getName());
goto after_isinstance;
}
if (!isa<VariableDeclaration>(Decl)) {
// TODO print diagnostic
// DE.add<MustBeVariableDeclaration>(Decl);
goto after_isinstance;
}
auto FuncDecl = cast<FunctionDeclaration>(Decl);
auto Params = addClassVars(Class, false); auto Params = addClassVars(Class, false);
@ -536,23 +556,25 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
// It would be very strange if there was no type assert in the type // It would be very strange if there was no type assert in the type
// class let-declaration but we rather not let the compiler crash if that happens. // class let-declaration but we rather not let the compiler crash if that happens.
if (SigLet->TypeAssert) { if (FuncDecl->hasTypeAssert()) {
// Note that we can't do SigLet->TypeAssert->TypeExpression->getType() // Note that we can't do SigLet->TypeAssert->TypeExpression->getType()
// because we need to re-generate the type within the local context of // because we need to re-generate the type within the local context of
// this let-declaration. // this let-declaration.
// TODO make CEqual accept multiple nodes // TODO make CEqual accept multiple nodes
makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let); makeEqual(Ty, inferTypeExpression(FuncDecl->getTypeAssert()->TypeExpression), Let);
} }
} }
if (Let->Body) { after_isinstance:
switch (Let->Body->getKind()) {
if (Let->hasBody()) {
switch (Let->getBody()->getKind()) {
case NodeKind::LetExprBody: case NodeKind::LetExprBody:
break; break;
case NodeKind::LetBlockBody: case NodeKind::LetBlockBody:
{ {
auto Block = static_cast<LetBlockBody*>(Let->Body); auto Block = static_cast<LetBlockBody*>(Let->getBody());
Let->Ctx->ReturnType = createTypeVar(); Let->Ctx->ReturnType = createTypeVar();
for (auto Element: Block->Elements) { for (auto Element: Block->Elements) {
forwardDeclare(Element); forwardDeclare(Element);
@ -570,11 +592,7 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
} }
void Checker::inferFunctionDeclaration(LetDeclaration* Decl) { void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) {
if (!Decl->isFunction()) {
return;
}
// std::cerr << "infer " << Decl->getNameAsString() << std::endl; // std::cerr << "infer " << Decl->getNameAsString() << std::endl;
@ -584,21 +602,21 @@ void Checker::inferFunctionDeclaration(LetDeclaration* Decl) {
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
Type* RetType; Type* RetType;
for (auto Param: Decl->Params) { for (auto Param: Decl->getParams()) {
ParamTypes.push_back(inferPattern(Param->Pattern)); ParamTypes.push_back(inferPattern(Param->Pattern));
} }
if (Decl->Body) { if (Decl->hasBody()) {
switch (Decl->Body->getKind()) { switch (Decl->getBody()->getKind()) {
case NodeKind::LetExprBody: case NodeKind::LetExprBody:
{ {
auto Expr = static_cast<LetExprBody*>(Decl->Body); auto Expr = static_cast<LetExprBody*>(Decl->getBody());
RetType = inferExpression(Expr->Expression); RetType = inferExpression(Expr->Expression);
break; break;
} }
case NodeKind::LetBlockBody: case NodeKind::LetBlockBody:
{ {
auto Block = static_cast<LetBlockBody*>(Decl->Body); auto Block = static_cast<LetBlockBody*>(Decl->getBody());
RetType = Decl->Ctx->ReturnType; RetType = Decl->Ctx->ReturnType;
for (auto Element: Block->Elements) { for (auto Element: Block->Elements) {
infer(Element); infer(Element);
@ -680,29 +698,34 @@ void Checker::infer(Node* N) {
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
{ {
// Function declarations are handled separately in inferFunctionDeclaration() auto Decl = static_cast<FunctionDeclaration*>(N);
auto Decl = static_cast<LetDeclaration*>(N);
if (Decl->Visited) { if (Decl->Visited) {
break; break;
} }
if (Decl->isFunction()) { Decl->IsCycleActive = true;
Decl->IsCycleActive = true; Decl->Visited = true;
Decl->Visited = true; inferFunctionDeclaration(Decl);
inferFunctionDeclaration(Decl); Decl->IsCycleActive = false;
Decl->IsCycleActive = false; break;
} else if (Decl->isVariable()) { }
auto Ty = Decl->getType();
if (Decl->Body) { case NodeKind::VariableDeclaration:
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); {
auto E = static_cast<LetExprBody*>(Decl->Body); auto Decl = static_cast<VariableDeclaration*>(N);
auto Ty2 = inferExpression(E->Expression); auto Ty = Decl->getType();
makeEqual(Ty, Ty2, Decl); if (Decl->Body) {
} ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto Ty3 = inferPattern(Decl->Pattern); auto E = static_cast<LetExprBody*>(Decl->Body);
makeEqual(Ty, Ty3, Decl); auto Ty2 = inferExpression(E->Expression);
makeEqual(Ty, Ty2, Decl);
} }
auto Ty3 = inferPattern(Decl->Pattern);
makeEqual(Ty, Ty3, Decl);
break; break;
} }
@ -801,7 +824,7 @@ void Checker::inferConstraintExpression(ConstraintExpression* C) {
std::vector<Type*> Types; std::vector<Type*> Types;
for (auto TE: D->TEs) { for (auto TE: D->TEs) {
auto Ty = inferTypeExpression(TE); auto Ty = inferTypeExpression(TE);
Ty->asVar().Provided->emplace(getCanonicalText(D->Name)); Ty->asVar().Provided->emplace(D->Name->getCanonicalText());
Types.push_back(Ty); Types.push_back(Ty);
} }
break; break;
@ -824,10 +847,10 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
case NodeKind::ReferenceTypeExpression: case NodeKind::ReferenceTypeExpression:
{ {
auto RefTE = static_cast<ReferenceTypeExpression*>(N); auto RefTE = static_cast<ReferenceTypeExpression*>(N);
auto Scm = lookup(getCanonicalText(RefTE->Name), SymKind::Type); auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type);
Type* Ty; Type* Ty;
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(RefTE->Name), RefTE->Name); DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
Ty = createTypeVar(); Ty = createTypeVar();
} else { } else {
Ty = instantiate(Scm, RefTE); Ty = instantiate(Scm, RefTE);
@ -850,13 +873,13 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
case NodeKind::VarTypeExpression: case NodeKind::VarTypeExpression:
{ {
auto VarTE = static_cast<VarTypeExpression*>(N); auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(getCanonicalText(VarTE->Name), SymKind::Type); auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
if (Ty == nullptr) { if (Ty == nullptr) {
if (!AutoVars || Config.typeVarsRequireForall()) { if (!AutoVars || Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(VarTE->Name), VarTE->Name); DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
} }
Ty = createRigidVar(getCanonicalText(VarTE->Name)); Ty = createRigidVar(VarTE->Name->getCanonicalText());
addBinding(getCanonicalText(VarTE->Name), new Forall(Ty), SymKind::Type); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
} }
ZEN_ASSERT(Ty->isVar()); ZEN_ASSERT(Ty->isVar());
N->setType(Ty); N->setType(Ty);
@ -868,7 +891,7 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
auto RecTE = static_cast<RecordTypeExpression*>(N); auto RecTE = static_cast<RecordTypeExpression*>(N);
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil());
for (auto [Field, Comma]: RecTE->Fields) { for (auto [Field, Comma]: RecTE->Fields) {
Ty = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty));
} }
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
@ -980,7 +1003,7 @@ Type* Checker::inferExpression(Expression* X) {
Ty = new Type(TNil()); Ty = new Type(TNil());
for (auto [Field, Comma]: Record->Fields) { for (auto [Field, Comma]: Record->Fields) {
Ty = new Type(TField( Ty = new Type(TField(
getCanonicalText(Field->Name), Field->Name->getCanonicalText(),
new Type(TPresent(inferExpression(Field->getExpression()))), new Type(TPresent(inferExpression(Field->getExpression()))),
Ty Ty
)); ));
@ -999,11 +1022,12 @@ Type* Checker::inferExpression(Expression* X) {
case NodeKind::ReferenceExpression: case NodeKind::ReferenceExpression:
{ {
auto Ref = static_cast<ReferenceExpression*>(X); auto Ref = static_cast<ReferenceExpression*>(X);
auto Name = Ref->Name.getCanonicalText();
ZEN_ASSERT(Ref->ModulePath.empty()); ZEN_ASSERT(Ref->ModulePath.empty());
if (Ref->Name->is<IdentifierAlt>()) { if (Ref->Name.isIdentifierAlt()) {
auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); auto Scm = lookup(Name, SymKind::Var);
if (!Scm) { if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(Ref->Name), Ref->Name); DE.add<BindingNotFoundDiagnostic>(Name, Ref->Name);
Ty = createTypeVar(); Ty = createTypeVar();
break; break;
} }
@ -1012,12 +1036,12 @@ Type* Checker::inferExpression(Expression* X) {
} }
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
if (!Target) { if (!Target) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(Ref->Name), Ref->Name); DE.add<BindingNotFoundDiagnostic>(Name, Ref->Name);
Ty = createTypeVar(); Ty = createTypeVar();
break; break;
} }
if (Target->getKind() == NodeKind::LetDeclaration) { if (isa<FunctionDeclaration>(Target)) {
auto Let = static_cast<LetDeclaration*>(Target); auto Let = static_cast<FunctionDeclaration*>(Target);
if (Let->IsCycleActive) { if (Let->IsCycleActive) {
Ty = Let->getType(); Ty = Let->getType();
break; break;
@ -1026,7 +1050,7 @@ Type* Checker::inferExpression(Expression* X) {
infer(Let); infer(Let);
} }
} }
auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); auto Scm = lookup(Name, SymKind::Var);
ZEN_ASSERT(Scm); ZEN_ASSERT(Scm);
Ty = instantiate(Scm, X); Ty = instantiate(Scm, X);
break; break;
@ -1048,9 +1072,9 @@ Type* Checker::inferExpression(Expression* X) {
case NodeKind::InfixExpression: case NodeKind::InfixExpression:
{ {
auto Infix = static_cast<InfixExpression*>(X); auto Infix = static_cast<InfixExpression*>(X);
auto Scm = lookup(Infix->Operator->getText(), SymKind::Var); auto Scm = lookup(Infix->Operator.getCanonicalText(), SymKind::Var);
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator); DE.add<BindingNotFoundDiagnostic>(Infix->Operator.getCanonicalText(), Infix->Operator);
Ty = createTypeVar(); Ty = createTypeVar();
break; break;
} }
@ -1091,7 +1115,7 @@ Type* Checker::inferExpression(Expression* X) {
auto K = static_cast<Identifier*>(Member->Name); auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar(); Ty = createTypeVar();
auto RestTy = createTypeVar(); auto RestTy = createTypeVar();
makeEqual(new Type(TField(getCanonicalText(K), Ty, RestTy)), ExprTy, Member); makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member);
break; break;
} }
default: default:
@ -1138,20 +1162,20 @@ Type* Checker::inferPattern(
{ {
auto P = static_cast<BindPattern*>(Pattern); auto P = static_cast<BindPattern*>(Pattern);
auto Ty = createTypeVar(); auto Ty = createTypeVar();
addBinding(getCanonicalText(P->Name), new Forall(TVs, Constraints, Ty), SymKind::Var); addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var);
return Ty; return Ty;
} }
case NodeKind::NamedTuplePattern: case NodeKind::NamedTuplePattern:
{ {
auto P = static_cast<NamedTuplePattern*>(Pattern); auto P = static_cast<NamedTuplePattern*>(Pattern);
auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
std::vector<Type*> ElementTypes; std::vector<Type*> ElementTypes;
for (auto P2: P->Patterns) { for (auto P2: P->Patterns) {
ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); ElementTypes.push_back(inferPattern(P2, Constraints, TVs));
} }
if (!Scm) { if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(P->Name), P->Name); DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar(); return createTypeVar();
} }
auto Ty = instantiate(Scm, P); auto Ty = instantiate(Scm, P);
@ -1181,9 +1205,9 @@ Type* Checker::inferPattern(
FieldTy = inferPattern(Field->Pattern, Constraints, TVs); FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else { } else {
FieldTy = createTypeVar(); FieldTy = createTypeVar();
addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
} }
RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
} }
return RecordTy; return RecordTy;
} }
@ -1191,9 +1215,9 @@ Type* Checker::inferPattern(
case NodeKind::NamedRecordPattern: case NodeKind::NamedRecordPattern:
{ {
auto P = static_cast<NamedRecordPattern*>(Pattern); auto P = static_cast<NamedRecordPattern*>(Pattern);
auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
if (Scm == nullptr) { if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(P->Name), P->Name); DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar(); return createTypeVar();
} }
auto RestField = getRestField(P->Fields); auto RestField = getRestField(P->Fields);
@ -1214,9 +1238,9 @@ Type* Checker::inferPattern(
FieldTy = inferPattern(Field->Pattern, Constraints, TVs); FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else { } else {
FieldTy = createTypeVar(); FieldTy = createTypeVar();
addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
} }
RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
} }
auto Ty = instantiate(Scm, P); auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar(); auto RetTy = createTypeVar();
@ -1287,7 +1311,14 @@ void Checker::populate(SourceFile* SF) {
std::stack<Node*> Stack; std::stack<Node*> Stack;
void visitLetDeclaration(LetDeclaration* N) { void visitFunctionDeclaration(FunctionDeclaration* N) {
RefGraph.addVertex(N);
Stack.push(N);
visitEachChild(N);
Stack.pop();
}
void visitVariableDeclaration(VariableDeclaration* N) {
RefGraph.addVertex(N); RefGraph.addVertex(N);
Stack.push(N); Stack.push(N);
visitEachChild(N); visitEachChild(N);
@ -1295,22 +1326,26 @@ void Checker::populate(SourceFile* SF) {
} }
void visitReferenceExpression(ReferenceExpression* N) { void visitReferenceExpression(ReferenceExpression* N) {
auto Y = static_cast<ReferenceExpression*>(N); auto Ref = static_cast<ReferenceExpression*>(N);
auto Def = Y->getScope()->lookup(Y->getSymbolPath()); auto Def = Ref->getScope()->lookup(Ref->getSymbolPath());
// Name lookup failures will be reported directly in inferExpression(). if (Def == nullptr) {
if (Def == nullptr || Def->getKind() != NodeKind::LetDeclaration) { // Name lookup failures will be reported directly in inferExpression().
return; return;
} }
ZEN_ASSERT(isa<FunctionDeclaration>(Def) || isa<VariableDeclaration>(Def) || isa<Parameter>(Def));
// This case ensures that a deeply nested structure that references a // This case ensures that a deeply nested structure that references a
// parameter of a parent node but is not referenced itself is correctly handled. // parameter of a parent node but is not referenced itself is correctly handled.
// Note that the edge goes from the parent let to the parameter. This is normal. // Note that the edge goes from the parent let to the parameter. This is normal.
if (Def->getKind() == NodeKind::Parameter) { // if (Def->getKind() == NodeKind::Parameter) {
RefGraph.addEdge(Stack.top(), Def->Parent); // RefGraph.addEdge(Stack.top(), Def->Parent);
// return;
// }
if (Stack.empty()) {
// An empty stack means we are traversing the toplevel of the source
// file, in which case we don't have anyting to connect with.
return; return;
} }
if (!Stack.empty()) { RefGraph.addEdge(Def, Stack.top());
RefGraph.addEdge(Def, Stack.top());
}
} }
}; };
@ -1353,10 +1388,10 @@ void Checker::check(SourceFile *SF) {
auto TVs = new TVSet; auto TVs = new TVSet;
auto Constraints = new ConstraintSet; auto Constraints = new ConstraintSet;
for (auto N: Nodes) { for (auto N: Nodes) {
if (N->getKind() != NodeKind::LetDeclaration) { if (!isa<FunctionDeclaration>(N)) {
continue; continue;
} }
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<FunctionDeclaration*>(N);
forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); forwardDeclareFunctionDeclaration(Decl, TVs, Constraints);
} }
} }

View file

@ -155,7 +155,12 @@ static std::string describe(NodeKind Type) {
return "'class'"; return "'class'";
case NodeKind::InstanceKeyword: case NodeKind::InstanceKeyword:
return "'instance'"; return "'instance'";
case NodeKind::LetDeclaration: case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
return "a let-declaration";
case NodeKind::VariableDeclaration:
return "a let-declaration"; return "a let-declaration";
case NodeKind::CallExpression: case NodeKind::CallExpression:
return "a call-expression"; return "a call-expression";

View file

@ -11,7 +11,7 @@ Value Evaluator::evaluateExpression(Expression* X, Env& Env) {
case NodeKind::ReferenceExpression: case NodeKind::ReferenceExpression:
{ {
auto RE = static_cast<ReferenceExpression*>(X); auto RE = static_cast<ReferenceExpression*>(X);
return Env.lookup(getCanonicalText(RE->Name)); return Env.lookup(RE->Name.getCanonicalText());
// auto Decl = RE->getScope()->lookup(RE->getSymbolPath()); // auto Decl = RE->getScope()->lookup(RE->getSymbolPath());
// ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration); // ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration);
// return static_cast<FunctionDeclaration*>(Decl); // return static_cast<FunctionDeclaration*>(Decl);
@ -48,7 +48,7 @@ void Evaluator::assignPattern(Pattern* P, Value& V, Env& E) {
case NodeKind::BindPattern: case NodeKind::BindPattern:
{ {
auto BP = static_cast<BindPattern*>(P); auto BP = static_cast<BindPattern*>(P);
E.add(getCanonicalText(BP->Name), V); E.add(BP->Name->getCanonicalText(), V);
break; break;
} }
default: default:
@ -62,12 +62,12 @@ Value Evaluator::apply(Value Op, std::vector<Value> Args) {
{ {
auto Fn = Op.getDeclaration(); auto Fn = Op.getDeclaration();
Env NewEnv; Env NewEnv;
for (auto [Param, Arg]: zen::zip(Fn->Params, Args)) { for (auto [Param, Arg]: zen::zip(Fn->getParams(), Args)) {
assignPattern(Param->Pattern, Arg, NewEnv); assignPattern(Param->Pattern, Arg, NewEnv);
} }
switch (Fn->Body->getKind()) { switch (Fn->getBody()->getKind()) {
case NodeKind::LetExprBody: case NodeKind::LetExprBody:
return evaluateExpression(static_cast<LetExprBody*>(Fn->Body)->Expression, NewEnv); return evaluateExpression(static_cast<LetExprBody*>(Fn->getBody())->Expression, NewEnv);
default: default:
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
@ -98,23 +98,28 @@ void Evaluator::evaluate(Node* N, Env& E) {
evaluateExpression(ES->Expression, E); evaluateExpression(ES->Expression, E);
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
{ {
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<FunctionDeclaration*>(N);
if (Decl->isFunction()) { E.add(Decl->getNameAsString(), Decl);
E.add(Decl->getNameAsString(), Decl); break;
} else { }
Value V; case NodeKind::VariableDeclaration:
if (Decl->Body) { {
switch (Decl->Body->getKind()) { auto Decl = static_cast<VariableDeclaration*>(N);
case NodeKind::LetExprBody: Value V;
{ if (Decl->Body) {
auto Body = static_cast<LetExprBody*>(Decl->Body); switch (Decl->Body->getKind()) {
V = evaluateExpression(Body->Expression, E); case NodeKind::LetExprBody:
} {
default: auto Body = static_cast<LetExprBody*>(Decl->Body);
ZEN_UNREACHABLE V = evaluateExpression(Body->Expression, E);
} }
default:
ZEN_UNREACHABLE
} }
} }
break; break;

View file

@ -32,16 +32,6 @@
namespace bolt { namespace bolt {
bool isOperator(Token* T) {
switch (T->getKind()) {
case NodeKind::VBar:
case NodeKind::CustomOperator:
return true;
default:
return false;
}
}
std::optional<OperatorInfo> OperatorTable::getInfix(Token* T) { std::optional<OperatorInfo> OperatorTable::getInfix(Token* T) {
auto Match = Mapping.find(T->getText()); auto Match = Mapping.find(T->getText());
if (Match == Mapping.end() || !Match->second.isInfix()) { if (Match == Mapping.end() || !Match->second.isInfix()) {
@ -828,7 +818,7 @@ Expression* Parser::parsePrimitiveExpression() {
DE.add<UnexpectedTokenDiagnostic>(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt }); DE.add<UnexpectedTokenDiagnostic>(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt });
return nullptr; return nullptr;
} }
return new ReferenceExpression(Annotations, ModulePath, static_cast<Symbol*>(T3)); return new ReferenceExpression(Annotations, ModulePath, Symbol::from_raw_node(T3));
} }
case NodeKind::LParen: case NodeKind::LParen:
{ {
@ -1025,7 +1015,7 @@ Expression* Parser::parseInfixOperatorAfterExpression(Expression* Left, int MinP
} }
Right = NewRight; Right = NewRight;
} }
Left = new InfixExpression(Left, T0, Right); Left = new InfixExpression(Left, Operator::from_raw_node(T0), Right);
} }
return Left; return Left;
} }
@ -1141,17 +1131,31 @@ IfStatement* Parser::parseIfStatement() {
return new IfStatement(Parts); return new IfStatement(Parts);
} }
LetDeclaration* Parser::parseLetDeclaration() { enum class LetMode {
Prefix,
Infix,
Suffix,
Wrapped,
VarOrNamed,
};
Node* Parser::parseLetDeclaration() {
auto Annotations = parseAnnotations(); auto Annotations = parseAnnotations();
PubKeyword* Pub = nullptr; PubKeyword* Pub = nullptr;
ForeignKeyword* Foreign = nullptr; ForeignKeyword* Foreign = nullptr;
LetKeyword* Let; LetKeyword* Let;
MutKeyword* Mut = nullptr; MutKeyword* Mut = nullptr;
Operator Op;
Symbol Sym;
Pattern* Name; Pattern* Name;
Parameter* Param;
Parameter* Left;
Parameter* Right;
std::vector<Parameter*> Params; std::vector<Parameter*> Params;
TypeAssert* TA = nullptr; TypeAssert* TA = nullptr;
LetBody* Body = nullptr; LetBody* Body = nullptr;
LetMode Mode;
auto T0 = Tokens.get(); auto T0 = Tokens.get();
if (T0->getKind() == NodeKind::PubKeyword) { if (T0->getKind() == NodeKind::PubKeyword) {
@ -1183,38 +1187,46 @@ LetDeclaration* Parser::parseLetDeclaration() {
auto T2 = Tokens.peek(0); auto T2 = Tokens.peek(0);
auto T3 = Tokens.peek(1); auto T3 = Tokens.peek(1);
auto T4 = Tokens.peek(2); auto T4 = Tokens.peek(2);
if (isOperator(T2)) { if (isa<Operator>(T2)) {
// Prefix function declaration
Tokens.get(); Tokens.get();
auto P1 = parseNarrowPattern(); auto P1 = parseNarrowPattern();
Params.push_back(new Parameter(P1, nullptr)); Param = new Parameter(P1, nullptr);
Name = new BindPattern(T2); Op = Operator::from_raw_node(T2);
Mode = LetMode::Prefix;
goto after_params; goto after_params;
} else if (isOperator(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) { } else if (isa<Operator>(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) {
// Sufffix function declaration
auto P1 = parseNarrowPattern(); auto P1 = parseNarrowPattern();
Params.push_back(new Parameter(P1, nullptr)); Param = new Parameter(P1, nullptr);
Tokens.get(); Tokens.get();
Name = new BindPattern(T3); Op = Operator::from_raw_node(T3);
Mode = LetMode::Suffix;
goto after_params; goto after_params;
} else if (T2->getKind() == NodeKind::LParen && isOperator(T3) && T4->getKind() == NodeKind::RParen) { } else if (T2->getKind() == NodeKind::LParen && isa<Operator>(T3) && T4->getKind() == NodeKind::RParen) {
// Wrapped operator function declaration
Tokens.get(); Tokens.get();
Tokens.get(); Tokens.get();
Tokens.get(); Tokens.get();
Name = new BindPattern( Sym = new WrappedOperator(
new WrappedOperator( static_cast<class LParen*>(T2),
static_cast<class LParen*>(T2), Operator::from_raw_node(T3),
T3, static_cast<class RParen*>(T3)
static_cast<class RParen*>(T3)
)
); );
} else if (isOperator(T3)) { Mode = LetMode::Wrapped;
} else if (isa<Operator>(T3)) {
// Infix function declaration
auto P1 = parseNarrowPattern(); auto P1 = parseNarrowPattern();
Params.push_back(new Parameter(P1, nullptr)); Left = new Parameter(P1, nullptr);
Tokens.get(); Tokens.get();
auto P2 = parseNarrowPattern(); auto P2 = parseNarrowPattern();
Params.push_back(new Parameter(P2, nullptr)); Right = new Parameter(P2, nullptr);
Name = new BindPattern(T3); Op = Operator::from_raw_node(T3);
Mode = LetMode::Infix;
goto after_params; goto after_params;
} else { } else {
// Variable declaration or named function declaration
Mode = LetMode::VarOrNamed;
Name = parseNarrowPattern(); Name = parseNarrowPattern();
if (!Name) { if (!Name) {
if (Pub) { if (Pub) {
@ -1313,17 +1325,77 @@ after_params:
finish: finish:
return new LetDeclaration( switch (Mode) {
Annotations, case LetMode::Prefix:
Pub, return new PrefixFunctionDeclaration(
Foreign, Annotations,
Let, Pub,
Mut, Foreign,
Name, Let,
Params, Op,
TA, Param,
Body TA,
); Body
);
case LetMode::Suffix:
return new SuffixFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Param,
Op,
TA,
Body
);
case LetMode::Infix:
return new InfixFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Left,
Op,
Right,
TA,
Body
);
case LetMode::Wrapped:
return new NamedFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Sym,
Params,
TA,
Body
);
case LetMode::VarOrNamed:
if (Name->getKind() != NodeKind::BindPattern || Mut) {
// TODO assert Params is empty
return new VariableDeclaration(
Annotations,
Pub,
Foreign,
Let,
Mut,
Name,
TA,
Body
);
}
return new NamedFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Name->as<BindPattern>()->Name,
Params,
TA,
Body
);
}
} }
Node* Parser::parseLetBodyElement() { Node* Parser::parseLetBodyElement() {