Split LetDeclaration into VariableDeclaration and FunctionDeclaration

We only generate VariableDeclaration when we're absolutely sure it is a
variable.
This commit is contained in:
Sam Vervaeck 2024-03-09 12:51:35 +01:00
parent 2bfd88b05f
commit 719dbfcad4
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
12 changed files with 1156 additions and 564 deletions

View file

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.10)
project(Bolt CXX)
project(Bolt C CXX)
set(CMAKE_CXX_STANDARD 20)
@ -17,6 +17,8 @@ if (CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo" OR CMAKE_BUILD_TYPE STREQUAL "Deb
set(BOLT_DEBUG ON)
endif()
find_package(LLVM 18.1.0 REQUIRED)
add_library(
BoltCore
#src/Text.cc
@ -28,6 +30,7 @@ add_library(
src/Types.cc
src/Checker.cc
src/Evaluator.cc
src/Scope.cc
)
target_link_directories(
BoltCore
@ -61,6 +64,22 @@ target_link_libraries(
icuuc
)
add_library(
BoltLLVM
src/LLVMCodeGen.cc
)
llvm_map_components_to_libnames(llvm_libs support core irreader)
target_include_directories(BoltLLVM PRIVATE ${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
target_compile_definitions(BoltLLVM PRIVATE ${LLVM_DEFINITIONS_LIST})
target_link_libraries(
BoltLLVM
PUBLIC
BoltCore
${llvm_libs}
)
add_executable(
bolt
src/main.cc
@ -69,6 +88,7 @@ target_link_libraries(
bolt
PUBLIC
BoltCore
BoltLLVM
)
if (BOLT_ENABLE_TESTS)

View file

@ -1,11 +1,13 @@
#ifndef BOLT_CST_HPP
#define BOLT_CST_HPP
#include <cstdlib>
#include <limits>
#include <unordered_map>
#include <variant>
#include <vector>
#include "bolt/Common.hpp"
#include "zen/config.hpp"
#include "bolt/Integer.hpp"
@ -172,7 +174,11 @@ enum class NodeKind {
Parameter,
LetBlockBody,
LetExprBody,
LetDeclaration,
PrefixFunctionDeclaration,
InfixFunctionDeclaration,
SuffixFunctionDeclaration,
NamedFunctionDeclaration,
VariableDeclaration,
RecordDeclarationField,
RecordDeclaration,
VariantDeclaration,
@ -364,31 +370,6 @@ public:
};
/// Any node that can be used as an operator
///
/// This includes the following nodes:
/// - VBar
/// - CustomOperator
using Operator = Token;
/// Any node that can be used as a kind of identifier.
///
/// This includes the following nodes:
/// - Identifier
/// - IdentifierAlt
/// - WrappedOperator
using Symbol = Node;
inline bool isSymbol(const Node* N) {
return N->getKind() == NodeKind::Identifier
|| N->getKind() == NodeKind::IdentifierAlt
|| N->getKind() == NodeKind::WrappedOperator;
}
/// Get the text that is actually represented by a symbol, without all the
/// syntactic sugar.
ByteString getCanonicalText(const Symbol* N);
class Equals : public Token {
public:
@ -903,6 +884,8 @@ public:
std::string getText() const override;
std::string getCanonicalText() const;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::CustomOperator;
}
@ -935,6 +918,8 @@ public:
std::string getText() const override;
ByteString getCanonicalText() const;
bool isTypeVar() const;
static bool classof(const Node* N) {
@ -953,6 +938,8 @@ public:
std::string getText() const override;
ByteString getCanonicalText() const;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::IdentifierAlt;
}
@ -1021,6 +1008,182 @@ public:
};
/// Base node for things that can be used as an operator
///
/// This includes the following nodes:
/// - VBar
/// - CustomOperator
class Operator {
Node* N;
Operator(Node* N):
N(N) {}
public:
Operator() {}
Operator(VBar* N):
N(N) {}
Operator(CustomOperator* N):
N(N) {}
static Operator from_raw_node(Node* N) {
ZEN_ASSERT(isa<Operator>(N));
return N;
}
inline NodeKind getKind() const {
return N->getKind();
}
inline bool isVBar() const {
return N->getKind() == NodeKind::VBar;
}
inline bool isCustomOperator() const {
return N->getKind() == NodeKind::CustomOperator;
}
VBar* asVBar() const {
return static_cast<VBar*>(N);
}
CustomOperator* asCustomOperator() const {
return static_cast<CustomOperator*>(N);
}
operator Node*() const {
return N;
}
/// Get the name that is actually represented by an operator, without all the
/// syntactic sugar.
virtual ByteString getCanonicalText() const;
Token* getFirstToken() const;
Token* getLastToken() const;
inline static bool classof(const Node* N) {
return N->getKind() == NodeKind::VBar
|| N->getKind() == NodeKind::CustomOperator;
}
};
class WrappedOperator : public Node {
public:
class LParen* LParen;
Operator Op;
class RParen* RParen;
WrappedOperator(
class LParen* LParen,
Operator Operator,
class RParen* RParen
): Node(NodeKind::WrappedOperator),
LParen(LParen),
Op(Operator),
RParen(RParen) {}
inline Operator getOperator() const {
return Op;
}
ByteString getCanonicalText() const {
return Op.getCanonicalText();
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::WrappedOperator;
}
};
/// Base node for things that can be used as a symbol
///
/// This includes the following nodes:
/// - WrappedOperator
/// - Identifier
/// - IdentifierAlt
class Symbol {
Node* N;
Symbol(Node* N):
N(N) {}
public:
Symbol() {}
Symbol(WrappedOperator* N):
N(N) {}
Symbol(Identifier* N):
N(N) {}
Symbol(IdentifierAlt* N):
N(N) {}
static Symbol from_raw_node(Node* N) {
ZEN_ASSERT(isa<Symbol>(N));
return N;
}
NodeKind getKind() const {
return N->getKind();
}
bool isWrappedOperator() const {
return N->getKind() == NodeKind::WrappedOperator;
}
bool isIdentifier() const {
return N->getKind() == NodeKind::Identifier;
}
bool isIdentifierAlt() const {
return N->getKind() == NodeKind::IdentifierAlt;
}
IdentifierAlt* asIdentifierAlt() const {
return cast<IdentifierAlt>(N);
}
Identifier* asIdentifier() const {
return cast<Identifier>(N);
}
WrappedOperator* asWrappedOperator() const {
return cast<WrappedOperator>(N);
}
operator Node*() const {
return N;
}
/// Get the name that is actually represented by a symbol, without all the
/// syntactic sugar.
ByteString getCanonicalText() const;
Token* getFirstToken() const;
Token* getLastToken() const;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::Identifier
|| N->getKind() == NodeKind::IdentifierAlt
|| N->getKind() == NodeKind::WrappedOperator;
}
};
class Annotation : public Node {
public:
@ -1353,31 +1516,6 @@ public:
};
class WrappedOperator : public Symbol {
public:
class LParen* LParen;
Token* Op;
class RParen* RParen;
WrappedOperator(
class LParen* LParen,
Token* Operator,
class RParen* RParen
): Symbol(NodeKind::WrappedOperator),
LParen(LParen),
Op(Operator),
RParen(RParen) {}
inline Token* getOperator() const {
return Op;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class Pattern : public Node {
protected:
@ -1390,10 +1528,10 @@ protected:
class BindPattern : public Pattern {
public:
Symbol* Name;
Identifier* Name;
BindPattern(
Symbol* Name
Identifier* Name
): Pattern(NodeKind::BindPattern),
Name(Name) {}
@ -1608,11 +1746,11 @@ class ReferenceExpression : public Expression {
public:
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath;
Symbol* Name;
Symbol Name;
inline ReferenceExpression(
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath,
Symbol* Name
Symbol Name
): Expression(NodeKind::ReferenceExpression),
ModulePath(ModulePath),
Name(Name) {}
@ -1620,13 +1758,13 @@ public:
inline ReferenceExpression(
std::vector<Annotation*> Annotations,
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath,
Symbol* Name
Symbol Name
): Expression(NodeKind::ReferenceExpression, Annotations),
ModulePath(ModulePath),
Name(Name) {}
inline ByteString getNameAsString() const noexcept {
return getCanonicalText(Name);
return Name.getCanonicalText();
}
Token* getFirstToken() const override;
@ -1864,12 +2002,12 @@ class InfixExpression : public Expression {
public:
Expression* Left;
Token* Operator;
Operator Operator;
Expression* Right;
inline InfixExpression(
Expression* Left,
Token* Operator,
class Operator Operator,
Expression* Right
): Expression(NodeKind::InfixExpression),
Left(Left),
@ -1879,7 +2017,7 @@ public:
inline InfixExpression(
std::vector<Annotation*> Annotations,
Expression* Left,
Token* Operator,
class Operator Operator,
Expression* Right
): Expression(NodeKind::InfixExpression, Annotations),
Left(Left),
@ -2110,6 +2248,10 @@ public:
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::Parameter;
}
};
class LetBody : public Node {
@ -2155,7 +2297,7 @@ public:
};
class LetDeclaration : public TypedNode, public AnnotationContainer {
class FunctionDeclaration : public TypedNode, public AnnotationContainer {
Scope* TheScope = nullptr;
@ -2165,63 +2307,34 @@ public:
bool Visited = false;
InferContext* Ctx;
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
FunctionDeclaration(NodeKind Kind, std::vector<Annotation*> Annotations = {}):
TypedNode(Kind), AnnotationContainer(Annotations) {}
LetDeclaration(
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::LetDeclaration),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
virtual bool isPublic() const = 0;
LetDeclaration(
std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::LetDeclaration),
AnnotationContainer(Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
virtual bool isForeign() const = 0;
virtual ByteString getNameAsString() const = 0;
virtual std::vector<Parameter*> getParams() const = 0;
virtual TypeAssert* getTypeAssert() const = 0;
bool hasTypeAssert() const {
return getTypeAssert();
}
virtual LetBody* getBody() const = 0;
bool hasBody() const {
return getBody();
}
inline Scope* getScope() override {
if (isFunction()) {
if (TheScope == nullptr) {
TheScope = new Scope(this);
}
return TheScope;
if (TheScope == nullptr) {
TheScope = new Scope(this);
}
return Parent->getScope();
return TheScope;
}
bool isInstance() const noexcept {
@ -2232,33 +2345,310 @@ public:
return Parent->getKind() == NodeKind::ClassDeclaration;
}
bool isSignature() const noexcept {
return ForeignKeyword;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::PrefixFunctionDeclaration
|| N->getKind() == NodeKind::InfixFunctionDeclaration
|| N->getKind() == NodeKind::SuffixFunctionDeclaration
|| N->getKind() == NodeKind::NamedFunctionDeclaration;
}
bool isVariable() const noexcept {
// Variables in classes and instances are never possible, so we reflect this by excluding them here.
return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body);
};
class PrefixFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class Operator Name;
Parameter* Param;
class TypeAssert* TypeAssert;
LetBody* Body;
PrefixFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
Operator Name,
Parameter* Param,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::PrefixFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Name(Name),
Param(Param),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isFunction() const noexcept {
return !isSignature() && !isVariable();
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
Symbol* getName() const noexcept {
ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern);
return static_cast<BindPattern*>(Pattern)->Name;
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
ByteString getNameAsString() const noexcept {
return getCanonicalText(getName());
std::vector<Parameter*> getParams() const override {
return { Param };
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::LetDeclaration;
return N->getKind() == NodeKind::PrefixFunctionDeclaration;
}
};
class SuffixFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
Parameter* Param;
class Operator Name;
class TypeAssert* TypeAssert;
LetBody* Body;
SuffixFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
Parameter* Param,
Operator Name,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::SuffixFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Name(Name),
Param(Param),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
std::vector<Parameter*> getParams() const override {
return { Param };
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::SuffixFunctionDeclaration;
}
};
class InfixFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
Parameter* Left;
class Operator Name;
Parameter* Right;
class TypeAssert* TypeAssert;
LetBody* Body;
InfixFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
Parameter* Left,
class Operator Name,
Parameter* Right,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::InfixFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Left(Left),
Name(Name),
Right(Right),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
std::vector<Parameter*> getParams() const override {
return { Left, Right };
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::InfixFunctionDeclaration;
}
};
class NamedFunctionDeclaration : public FunctionDeclaration {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class Symbol Name;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
NamedFunctionDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class Symbol Name,
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert,
LetBody* Body
): FunctionDeclaration(NodeKind::NamedFunctionDeclaration, Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
Name(Name),
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
bool isPublic() const override {
return PubKeyword != nullptr;
}
bool isForeign() const override {
return ForeignKeyword != nullptr;
}
ByteString getNameAsString() const override {
return Name.getCanonicalText();
}
std::vector<Parameter*> getParams() const override {
return Params;
}
class TypeAssert* getTypeAssert() const override {
return TypeAssert;
}
LetBody* getBody() const override {
return Body;
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::NamedFunctionDeclaration;
}
};
class VariableDeclaration : public TypedNode, public AnnotationContainer {
public:
class PubKeyword* PubKeyword;
class ForeignKeyword* ForeignKeyword;
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
VariableDeclaration(
class std::vector<Annotation*> Annotations,
class PubKeyword* PubKeyword,
class ForeignKeyword* ForeignKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::VariableDeclaration),
AnnotationContainer(Annotations),
PubKeyword(PubKeyword),
ForeignKeyword(ForeignKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
TypeAssert(TypeAssert),
Body(Body) {}
Symbol getName() const noexcept {
ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern);
return static_cast<BindPattern*>(Pattern)->Name;
}
ByteString getNameAsString() const noexcept {
return getName().getCanonicalText();
}
Token* getFirstToken() const override;
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::VariableDeclaration;
}
};
@ -2555,7 +2945,11 @@ template<> inline NodeKind getNodeType<TypeAssert>() { return NodeKind::TypeAsse
template<> inline NodeKind getNodeType<Parameter>() { return NodeKind::Parameter; }
template<> inline NodeKind getNodeType<LetBlockBody>() { return NodeKind::LetBlockBody; }
template<> inline NodeKind getNodeType<LetExprBody>() { return NodeKind::LetExprBody; }
template<> inline NodeKind getNodeType<LetDeclaration>() { return NodeKind::LetDeclaration; }
template<> inline NodeKind getNodeType<PrefixFunctionDeclaration>() { return NodeKind::PrefixFunctionDeclaration; }
template<> inline NodeKind getNodeType<InfixFunctionDeclaration>() { return NodeKind::InfixFunctionDeclaration; }
template<> inline NodeKind getNodeType<SuffixFunctionDeclaration>() { return NodeKind::SuffixFunctionDeclaration; }
template<> inline NodeKind getNodeType<NamedFunctionDeclaration()>() { return NodeKind::NamedFunctionDeclaration; }
template<> inline NodeKind getNodeType<VariableDeclaration>() { return NodeKind::VariableDeclaration; }
template<> inline NodeKind getNodeType<RecordDeclarationField>() { return NodeKind::RecordDeclarationField; }
template<> inline NodeKind getNodeType<RecordDeclaration>() { return NodeKind::RecordDeclaration; }
template<> inline NodeKind getNodeType<ClassDeclaration>() { return NodeKind::ClassDeclaration; }

View file

@ -15,8 +15,8 @@ public:
void visit(Node* N) {
#define BOLT_GEN_CASE(name) \
case NodeKind::name: \
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
case NodeKind::name: \
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
switch (N->getKind()) {
BOLT_GEN_CASE(VBar)
@ -104,7 +104,11 @@ public:
BOLT_GEN_CASE(Parameter)
BOLT_GEN_CASE(LetBlockBody)
BOLT_GEN_CASE(LetExprBody)
BOLT_GEN_CASE(LetDeclaration)
BOLT_GEN_CASE(PrefixFunctionDeclaration)
BOLT_GEN_CASE(InfixFunctionDeclaration)
BOLT_GEN_CASE(SuffixFunctionDeclaration)
BOLT_GEN_CASE(NamedFunctionDeclaration)
BOLT_GEN_CASE(VariableDeclaration)
BOLT_GEN_CASE(RecordDeclaration)
BOLT_GEN_CASE(RecordDeclarationField)
BOLT_GEN_CASE(VariantDeclaration)
@ -116,6 +120,35 @@ public:
}
}
void dispatchSymbol(const Symbol& S) {
switch (S.getKind()) {
case NodeKind::Identifier:
visit(S.asIdentifier());
break;
case NodeKind::IdentifierAlt:
visit(S.asIdentifierAlt());
break;
case NodeKind::WrappedOperator:
visit(S.asWrappedOperator());
break;
default:
ZEN_UNREACHABLE
}
}
void dispatchOperator(const Operator& O) {
switch (O.getKind()) {
case NodeKind::VBar:
visit(O.asVBar());
break;
case NodeKind::CustomOperator:
visit(O.asCustomOperator());
break;
default:
ZEN_UNREACHABLE
}
}
protected:
void visitNode(Node* N) {
@ -494,7 +527,27 @@ protected:
static_cast<D*>(this)->visitLetBody(N);
}
void visitLetDeclaration(LetDeclaration* N) {
void visitFunctionDeclaration(FunctionDeclaration* N) {
static_cast<D*>(this)->visitNode(N);
}
void visitPrefixFunctionDeclaration(PrefixFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitInfixFunctionDeclaration(InfixFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitSuffixFunctionDeclaration(SuffixFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitNamedFunctionDeclaration(NamedFunctionDeclaration* N) {
static_cast<D*>(this)->visitFunctionDeclaration(N);
}
void visitVariableDeclaration(VariableDeclaration* N) {
static_cast<D*>(this)->visitNode(N);
}
@ -629,7 +682,11 @@ public:
BOLT_GEN_CHILD_CASE(Parameter)
BOLT_GEN_CHILD_CASE(LetBlockBody)
BOLT_GEN_CHILD_CASE(LetExprBody)
BOLT_GEN_CHILD_CASE(LetDeclaration)
BOLT_GEN_CHILD_CASE(PrefixFunctionDeclaration)
BOLT_GEN_CHILD_CASE(InfixFunctionDeclaration)
BOLT_GEN_CHILD_CASE(SuffixFunctionDeclaration)
BOLT_GEN_CHILD_CASE(NamedFunctionDeclaration)
BOLT_GEN_CHILD_CASE(VariableDeclaration)
BOLT_GEN_CHILD_CASE(RecordDeclaration)
BOLT_GEN_CHILD_CASE(RecordDeclarationField)
BOLT_GEN_CHILD_CASE(VariantDeclaration)
@ -642,6 +699,8 @@ public:
}
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
#define BOLT_VISIT_SYMBOL(node) static_cast<D*>(this)->dispatchSymbol(node)
#define BOLT_VISIT_OPERATOR(node) static_cast<D*>(this)->dispatchOperator(node)
void visitEachChild(VBar* N) {
}
@ -771,7 +830,7 @@ public:
void visitEachChild(WrappedOperator* N) {
BOLT_VISIT(N->LParen);
BOLT_VISIT(N->Op);
BOLT_VISIT_OPERATOR(N->Op);
BOLT_VISIT(N->RParen);
}
@ -972,7 +1031,7 @@ public:
BOLT_VISIT(Name);
BOLT_VISIT(Dot);
}
BOLT_VISIT(N->Name);
BOLT_VISIT_SYMBOL(N->Name);
}
void visitEachChild(MatchCase* N) {
@ -1049,7 +1108,7 @@ public:
BOLT_VISIT(A);
}
BOLT_VISIT(N->Left);
BOLT_VISIT(N->Operator);
BOLT_VISIT_OPERATOR(N->Operator);
BOLT_VISIT(N->Right);
}
@ -1140,7 +1199,7 @@ public:
BOLT_VISIT(N->Expression);
}
void visitEachChild(LetDeclaration* N) {
void visitEachChild(PrefixFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
@ -1151,6 +1210,96 @@ public:
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT(N->Param);
BOLT_VISIT_OPERATOR(N->Name);
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(InfixFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT(N->Left);
BOLT_VISIT_OPERATOR(N->Name);
BOLT_VISIT(N->Right);
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(SuffixFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT_OPERATOR(N->Name);
BOLT_VISIT(N->Param);
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(NamedFunctionDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
BOLT_VISIT_SYMBOL(N->Name);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(VariableDeclaration* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
if (N->ForeignKeyword) {
BOLT_VISIT(N->ForeignKeyword);
}
BOLT_VISIT(N->LetKeyword);
if (N->MutKeyword) {
BOLT_VISIT(N->MutKeyword);
}
BOLT_VISIT(N->Pattern);
for (auto Param: N->Params) {
BOLT_VISIT(Param);

View file

@ -241,7 +241,7 @@ class Checker {
/// Type inference
void forwardDeclare(Node* Node);
void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true);
@ -249,7 +249,7 @@ class Checker {
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void infer(Node* node);
void inferFunctionDeclaration(LetDeclaration* N);
void inferFunctionDeclaration(FunctionDeclaration* N);
void inferConstraintExpression(ConstraintExpression* C);
/// Factory methods

View file

@ -1,6 +1,8 @@
#pragma once
#include "zen/config.hpp"
namespace bolt {
class LanguageConfig {

View file

@ -29,7 +29,7 @@ class Value {
union {
ByteString S;
Integer I;
LetDeclaration* D;
FunctionDeclaration* D;
NativeFunction F;
Tuple T;
};
@ -45,7 +45,7 @@ public:
Value(Integer I):
Kind(ValueKind::Integer), I(I) {}
Value(LetDeclaration* D):
Value(FunctionDeclaration* D):
Kind(ValueKind::SourceFunction), D(D) {}
Value(NativeFunction F):
@ -67,7 +67,7 @@ public:
new (&I) Tuple(V.T);
break;
case ValueKind::SourceFunction:
new (&D) LetDeclaration*(V.D);
new (&D) FunctionDeclaration*(V.D);
break;
case ValueKind::NativeFunction:
new (&F) NativeFunction(V.F);
@ -90,7 +90,7 @@ public:
new (&I) Tuple(Other.T);
break;
case ValueKind::SourceFunction:
new (&D) LetDeclaration*(Other.D);
new (&D) FunctionDeclaration*(Other.D);
break;
case ValueKind::NativeFunction:
new (&F) NativeFunction(Other.F);
@ -112,7 +112,7 @@ public:
return S;
}
inline LetDeclaration* getDeclaration() {
inline FunctionDeclaration* getDeclaration() {
ZEN_ASSERT(Kind == ValueKind::SourceFunction);
return D;
}

View file

@ -130,7 +130,7 @@ public:
Node* parseLetBodyElement();
LetDeclaration* parseLetDeclaration();
Node* parseLetDeclaration();
Node* parseClassElement();

View file

@ -59,216 +59,6 @@ ByteString TextFile::getText() const {
return Text;
}
Scope::Scope(Node* Source):
Source(Source) {
scan(Source);
}
void Scope::addSymbol(ByteString Name, Node* Decl, SymbolKind Kind) {
Mapping.emplace(Name, std::make_tuple(Decl, Kind));
}
void Scope::scan(Node* X) {
switch (X->getKind()) {
case NodeKind::SourceFile:
{
auto File = static_cast<SourceFile*>(X);
for (auto Element: File->Elements) {
scanChild(Element);
}
break;
}
case NodeKind::MatchCase:
{
auto Case = static_cast<MatchCase*>(X);
visitPattern(Case->Pattern, Case);
break;
}
case NodeKind::LetDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
ZEN_ASSERT(Decl->isFunction());
for (auto Param: Decl->Params) {
visitPattern(Param->Pattern, Param);
}
if (Decl->Body) {
scanChild(Decl->Body);
}
break;
}
default:
ZEN_UNREACHABLE
}
}
void Scope::scanChild(Node* X) {
switch (X->getKind()) {
case NodeKind::LetExprBody:
case NodeKind::ExpressionStatement:
case NodeKind::IfStatement:
case NodeKind::ReturnStatement:
break;
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(X);
for (auto Element: Block->Elements) {
scanChild(Element);
}
break;
}
case NodeKind::InstanceDeclaration:
// We ignore let-declarations inside instance-declarations for now
break;
case NodeKind::ClassDeclaration:
{
auto Decl = static_cast<ClassDeclaration*>(X);
addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Class);
for (auto Element: Decl->Elements) {
scanChild(Element);
}
break;
}
case NodeKind::LetDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
// No matter if it is a function or a variable, by visiting the pattern
// we add all relevant bindings to the current scope.
visitPattern(Decl->Pattern, Decl);
break;
}
case NodeKind::RecordDeclaration:
{
auto Decl = static_cast<RecordDeclaration*>(X);
addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type);
break;
}
case NodeKind::VariantDeclaration:
{
auto Decl = static_cast<VariantDeclaration*>(X);
addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type);
for (auto Member: Decl->Members) {
switch (Member->getKind()) {
case NodeKind::TupleVariantDeclarationMember:
{
auto T = static_cast<TupleVariantDeclarationMember*>(Member);
addSymbol(getCanonicalText(T->Name), Decl, SymbolKind::Constructor);
break;
}
case NodeKind::RecordVariantDeclarationMember:
{
auto R = static_cast<RecordVariantDeclarationMember*>(Member);
addSymbol(getCanonicalText(R->Name), Decl, SymbolKind::Constructor);
break;
}
default:
ZEN_UNREACHABLE
}
}
break;
}
default:
ZEN_UNREACHABLE
}
}
void Scope::visitPattern(Pattern* X, Node* Decl) {
switch (X->getKind()) {
case NodeKind::BindPattern:
{
auto Y = static_cast<BindPattern*>(X);
addSymbol(getCanonicalText(Y->Name), Decl, SymbolKind::Var);
break;
}
case NodeKind::RecordPattern:
{
auto Y = static_cast<RecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedRecordPattern:
{
auto Y = static_cast<NamedRecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedTuplePattern:
{
auto Y = static_cast<NamedTuplePattern*>(X);
for (auto P: Y->Patterns) {
visitPattern(P, Decl);
}
break;
}
case NodeKind::NestedPattern:
{
auto Y = static_cast<NestedPattern*>(X);
visitPattern(Y->P, Decl);
break;
}
case NodeKind::TuplePattern:
{
auto Y = static_cast<TuplePattern*>(X);
for (auto [Element, Comma]: Y->Elements) {
visitPattern(Element, Decl);
}
break;
}
case NodeKind::ListPattern:
{
auto Y = static_cast<ListPattern*>(X);
for (auto [Element, Separator]: Y->Elements) {
visitPattern(Element, Decl);
}
break;
}
case NodeKind::LiteralPattern:
break;
default:
ZEN_UNREACHABLE
}
}
Node* Scope::lookupDirect(SymbolPath Path, SymbolKind Kind) {
ZEN_ASSERT(Path.Modules.empty());
auto Match = Mapping.find(Path.Name);
if (Match != Mapping.end() && std::get<1>(Match->second) == Kind) {
return std::get<0>(Match->second);
}
return nullptr;
}
Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) {
ZEN_ASSERT(Path.Modules.empty());
auto Curr = this;
do {
auto Found = Curr->lookupDirect(Path, Kind);
if (Found) {
return Found;
}
Curr = Curr->getParentScope();
} while (Curr != nullptr);
return nullptr;
}
Scope* Scope::getParentScope() {
if (Source->Parent == nullptr) {
return nullptr;
}
return Source->Parent->getScope();
}
const SourceFile* Node::getSourceFile() const {
const Node* CurrNode = this;
for (;;) {
@ -503,29 +293,11 @@ Token* WrappedOperator::getLastToken() const {
}
Token* BindPattern::getFirstToken() const {
switch (Name->getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(Name);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->LParen;
default:
ZEN_UNREACHABLE
}
return Name;
}
Token* BindPattern::getLastToken() const {
switch (Name->getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(Name);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->RParen;
default:
ZEN_UNREACHABLE
}
return Name;
}
Token* LiteralPattern::getFirstToken() const {
@ -608,26 +380,26 @@ Token* ReferenceExpression::getFirstToken() const {
if (!ModulePath.empty()) {
return std::get<0>(ModulePath.front());
}
switch (Name->getKind()) {
switch (Name.getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(Name);
return Name.asIdentifier();
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name);
return Name.asIdentifierAlt();
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->LParen;
return Name.asWrappedOperator()->getFirstToken();
default:
ZEN_UNREACHABLE
}
}
Token* ReferenceExpression::getLastToken() const {
switch (Name->getKind()) {
switch (Name.getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(Name);
return Name.asIdentifier();
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(Name);
return Name.asIdentifierAlt();
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(Name)->RParen;
return Name.asWrappedOperator()->getLastToken();
default:
ZEN_UNREACHABLE
}
@ -805,7 +577,7 @@ Token* LetExprBody::getLastToken() const {
return Expression->getLastToken();
}
Token* LetDeclaration::getFirstToken() const {
Token* PrefixFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
@ -815,17 +587,97 @@ Token* LetDeclaration::getFirstToken() const {
return LetKeyword;
}
Token* LetDeclaration::getLastToken() const {
Token* PrefixFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
if (Params.size()) {
return Param->getLastToken();
}
Token* InfixFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* InfixFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
return Right->getLastToken();
}
Token* SuffixFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* SuffixFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
return Name.getLastToken();
}
Token* NamedFunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* NamedFunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
if (!Params.empty()) {
return Params.back()->getLastToken();
}
return Name.getLastToken();
}
Token* VariableDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
if (ForeignKeyword) {
return ForeignKeyword;
}
return LetKeyword;
}
Token* VariableDeclaration::getLastToken() const {
return Pattern->getLastToken();
if (TypeAssert) {
return TypeAssert->getLastToken();
}
if (Body) {
return Body->getLastToken();
}
}
Token* RecordDeclarationField::getFirstToken() const {
@ -1093,23 +945,81 @@ std::string InstanceKeyword::getText() const {
return "instance";
}
ByteString getCanonicalText(const Symbol* N) {
ByteString Identifier::getCanonicalText() const {
return Text;
}
ByteString IdentifierAlt::getCanonicalText() const {
return Text;
}
ByteString CustomOperator::getCanonicalText() const {
return Text;
}
ByteString Symbol::getCanonicalText() const {
switch (N->getKind()) {
case NodeKind::Identifier:
return static_cast<const Identifier*>(N)->Text;
return static_cast<const Identifier*>(N)->getCanonicalText();
case NodeKind::IdentifierAlt:
return static_cast<const IdentifierAlt*>(N)->Text;
return static_cast<const IdentifierAlt*>(N)->getCanonicalText();
case NodeKind::CustomOperator:
return static_cast<const CustomOperator*>(N)->Text;
return static_cast<const CustomOperator*>(N)->getCanonicalText();
case NodeKind::VBar:
return static_cast<const VBar*>(N)->getText();
case NodeKind::WrappedOperator:
return static_cast<const WrappedOperator*>(N)->getOperator()->getText();
return static_cast<const WrappedOperator*>(N)->getCanonicalText();
default:
ZEN_UNREACHABLE
}
}
Token* Symbol::getFirstToken() const {
switch (N->getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(N);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(N);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(N)->getFirstToken();
default:
ZEN_UNREACHABLE
}
}
Token* Symbol::getLastToken() const {
switch (N->getKind()) {
case NodeKind::Identifier:
return static_cast<Identifier*>(N);
case NodeKind::IdentifierAlt:
return static_cast<IdentifierAlt*>(N);
case NodeKind::WrappedOperator:
return static_cast<WrappedOperator*>(N)->getLastToken();
default:
ZEN_UNREACHABLE
}
}
ByteString Operator::getCanonicalText() const {
switch (N->getKind()) {
case NodeKind::CustomOperator:
return static_cast<const CustomOperator*>(N)->getCanonicalText();
case NodeKind::VBar:
return static_cast<const VBar*>(N)->getText();
default:
ZEN_UNREACHABLE
}
}
Token* Operator::getFirstToken() const {
return static_cast<Token*>(N);
}
Token* Operator::getLastToken() const {
return static_cast<Token*>(N);
}
LiteralValue StringLiteral::getValue() {
return Text;
}
@ -1121,9 +1031,9 @@ LiteralValue IntegerLiteral::getValue() {
SymbolPath ReferenceExpression::getSymbolPath() const {
std::vector<ByteString> ModuleNames;
for (auto [Name, Dot]: ModulePath) {
ModuleNames.push_back(getCanonicalText(Name));
ModuleNames.push_back(Name->getCanonicalText());
}
return SymbolPath { ModuleNames, getCanonicalText(Name) };
return SymbolPath { ModuleNames, Name.getCanonicalText() };
}
}

View file

@ -251,9 +251,9 @@ void Checker::forwardDeclare(Node* X) {
inferTypeExpression(TE);
}
auto Match = InstanceMap.find(getCanonicalText(Decl->Name));
auto Match = InstanceMap.find(Decl->Name->getCanonicalText());
if (Match == InstanceMap.end()) {
InstanceMap.emplace(getCanonicalText(Decl->Name), std::vector { Decl });
InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl });
} else {
Match->second.push_back(Decl);
}
@ -265,13 +265,15 @@ void Checker::forwardDeclare(Node* X) {
break;
}
case NodeKind::LetDeclaration:
case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
break;
case NodeKind::VariableDeclaration:
{
// Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression()
auto Decl = static_cast<LetDeclaration*>(X);
if (!Decl->isVariable()) {
break;
}
auto Decl = static_cast<VariableDeclaration*>(X);
Type* Ty;
if (Decl->TypeAssert) {
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression);
@ -290,13 +292,13 @@ void Checker::forwardDeclare(Node* X) {
std::vector<Type*> Vars;
for (auto TE: Decl->TVs) {
auto TV = createRigidVar(getCanonicalText(TE->Name));
auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type);
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV);
}
Type* Ty = createConType(getCanonicalText(Decl->Name));
Type* Ty = createConType(Decl->Name->getCanonicalText());
// Build the type that is actually returned by constructor functions
auto RetTy = Ty;
@ -305,7 +307,7 @@ void Checker::forwardDeclare(Node* X) {
}
// Must be added early so we can create recursive types
Decl->Ctx->Parent->Env.add(getCanonicalText(Decl->Name), new Forall(Ty), SymKind::Type);
Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
for (auto Member: Decl->Members) {
switch (Member->getKind()) {
@ -318,7 +320,7 @@ void Checker::forwardDeclare(Node* X) {
ParamTypes.push_back(inferTypeExpression(Element, false));
}
Decl->Ctx->Parent->Env.add(
getCanonicalText(TupleMember->Name),
TupleMember->Name->getCanonicalText(),
new Forall(
Decl->Ctx->TVs,
Decl->Ctx->Constraints,
@ -351,13 +353,13 @@ void Checker::forwardDeclare(Node* X) {
std::vector<Type*> Vars;
for (auto TE: Decl->Vars) {
auto TV = createRigidVar(getCanonicalText(TE->Name));
auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type);
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV);
}
auto Name = getCanonicalText(Decl->Name);
auto Name = Decl->Name->getCanonicalText();
auto Ty = createConType(Name);
// Must be added early so we can create recursive types
@ -373,7 +375,7 @@ void Checker::forwardDeclare(Node* X) {
for (auto Field: Decl->Fields) {
FieldsTy = new Type(
TField(
getCanonicalText(Field->Name),
Field->Name->getCanonicalText(),
new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))),
FieldsTy
)
@ -435,13 +437,11 @@ void Checker::initialize(Node* N) {
Contexts.pop();
}
void visitLetDeclaration(LetDeclaration* Let) {
if (Let->isFunction()) {
Let->Ctx = createDerivedContext();
Contexts.push(Let->Ctx);
visitEachChild(Let);
Contexts.pop();
}
void visitFunctionDeclaration(FunctionDeclaration* Func) {
Func->Ctx = createDerivedContext();
Contexts.push(Func->Ctx);
visitEachChild(Func);
Contexts.pop();
}
// void visitVariableDeclaration(VariableDeclaration* Var) {
@ -456,22 +456,18 @@ void Checker::initialize(Node* N) {
}
void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) {
if (!Let->isFunction()) {
return;
}
void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) {
// std::cerr << "declare " << Let->getNameAsString() << std::endl;
setContext(Let->Ctx);
auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) {
auto Id = getCanonicalText(Class->Name);
auto Id = Class->Name->getCanonicalText();
auto Ctx = &getContext();
std::vector<Type*> Out;
for (auto TE: Class->TypeVars) {
auto Name = getCanonicalText(TE->Name);
auto Name = TE->Name->getCanonicalText();
auto TV = IsRigid ? createRigidVar(Name) : createTypeVar();
TV->asVar().Context.emplace(Id);
Ctx->Env.add(Name, new Forall(TV), SymKind::Type);
@ -493,8 +489,8 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
// Otherwise, the type is not further specified and we create a new
// unification variable.
Type* Ty;
if (Let->TypeAssert) {
Ty = inferTypeExpression(Let->TypeAssert->TypeExpression);
if (Let->hasTypeAssert()) {
Ty = inferTypeExpression(Let->getTypeAssert()->TypeExpression);
} else {
Ty = createTypeVar();
}
@ -507,9 +503,33 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
if (Let->isInstance()) {
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
auto Class = cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, getCanonicalText(Instance->Name) }, SymbolKind::Class));
// TODO check if `Class` is nullptr
auto SigLet = cast<LetDeclaration>(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var));
auto Class = cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
if (Class == nullptr) {
// TODO print diagnostic
// DE.add<TypeclassNotFoundDiagnostic>(Instance->Name->getCanonicalText());
goto after_isinstance;
}
auto Decl = Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var);
if (Decl == nullptr) {
// TODO print diagnostic
// DE.add<MethodNotFoundInTypeclass>(Let->getNameAsStrings), Let->getName());
goto after_isinstance;
}
if (!isa<VariableDeclaration>(Decl)) {
// TODO print diagnostic
// DE.add<MustBeVariableDeclaration>(Decl);
goto after_isinstance;
}
auto FuncDecl = cast<FunctionDeclaration>(Decl);
auto Params = addClassVars(Class, false);
@ -536,23 +556,25 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
// It would be very strange if there was no type assert in the type
// class let-declaration but we rather not let the compiler crash if that happens.
if (SigLet->TypeAssert) {
if (FuncDecl->hasTypeAssert()) {
// Note that we can't do SigLet->TypeAssert->TypeExpression->getType()
// because we need to re-generate the type within the local context of
// this let-declaration.
// TODO make CEqual accept multiple nodes
makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let);
makeEqual(Ty, inferTypeExpression(FuncDecl->getTypeAssert()->TypeExpression), Let);
}
}
if (Let->Body) {
switch (Let->Body->getKind()) {
after_isinstance:
if (Let->hasBody()) {
switch (Let->getBody()->getKind()) {
case NodeKind::LetExprBody:
break;
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Let->Body);
auto Block = static_cast<LetBlockBody*>(Let->getBody());
Let->Ctx->ReturnType = createTypeVar();
for (auto Element: Block->Elements) {
forwardDeclare(Element);
@ -570,11 +592,7 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs,
}
void Checker::inferFunctionDeclaration(LetDeclaration* Decl) {
if (!Decl->isFunction()) {
return;
}
void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) {
// std::cerr << "infer " << Decl->getNameAsString() << std::endl;
@ -584,21 +602,21 @@ void Checker::inferFunctionDeclaration(LetDeclaration* Decl) {
std::vector<Type*> ParamTypes;
Type* RetType;
for (auto Param: Decl->Params) {
for (auto Param: Decl->getParams()) {
ParamTypes.push_back(inferPattern(Param->Pattern));
}
if (Decl->Body) {
switch (Decl->Body->getKind()) {
if (Decl->hasBody()) {
switch (Decl->getBody()->getKind()) {
case NodeKind::LetExprBody:
{
auto Expr = static_cast<LetExprBody*>(Decl->Body);
auto Expr = static_cast<LetExprBody*>(Decl->getBody());
RetType = inferExpression(Expr->Expression);
break;
}
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Decl->Body);
auto Block = static_cast<LetBlockBody*>(Decl->getBody());
RetType = Decl->Ctx->ReturnType;
for (auto Element: Block->Elements) {
infer(Element);
@ -680,29 +698,34 @@ void Checker::infer(Node* N) {
break;
}
case NodeKind::LetDeclaration:
case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
{
// Function declarations are handled separately in inferFunctionDeclaration()
auto Decl = static_cast<LetDeclaration*>(N);
auto Decl = static_cast<FunctionDeclaration*>(N);
if (Decl->Visited) {
break;
}
if (Decl->isFunction()) {
Decl->IsCycleActive = true;
Decl->Visited = true;
inferFunctionDeclaration(Decl);
Decl->IsCycleActive = false;
} else if (Decl->isVariable()) {
auto Ty = Decl->getType();
if (Decl->Body) {
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
makeEqual(Ty, Ty2, Decl);
}
auto Ty3 = inferPattern(Decl->Pattern);
makeEqual(Ty, Ty3, Decl);
Decl->IsCycleActive = true;
Decl->Visited = true;
inferFunctionDeclaration(Decl);
Decl->IsCycleActive = false;
break;
}
case NodeKind::VariableDeclaration:
{
auto Decl = static_cast<VariableDeclaration*>(N);
auto Ty = Decl->getType();
if (Decl->Body) {
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
makeEqual(Ty, Ty2, Decl);
}
auto Ty3 = inferPattern(Decl->Pattern);
makeEqual(Ty, Ty3, Decl);
break;
}
@ -801,7 +824,7 @@ void Checker::inferConstraintExpression(ConstraintExpression* C) {
std::vector<Type*> Types;
for (auto TE: D->TEs) {
auto Ty = inferTypeExpression(TE);
Ty->asVar().Provided->emplace(getCanonicalText(D->Name));
Ty->asVar().Provided->emplace(D->Name->getCanonicalText());
Types.push_back(Ty);
}
break;
@ -824,10 +847,10 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
case NodeKind::ReferenceTypeExpression:
{
auto RefTE = static_cast<ReferenceTypeExpression*>(N);
auto Scm = lookup(getCanonicalText(RefTE->Name), SymKind::Type);
auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type);
Type* Ty;
if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(RefTE->Name), RefTE->Name);
DE.add<BindingNotFoundDiagnostic>(RefTE->Name->getCanonicalText(), RefTE->Name);
Ty = createTypeVar();
} else {
Ty = instantiate(Scm, RefTE);
@ -850,13 +873,13 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
case NodeKind::VarTypeExpression:
{
auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(getCanonicalText(VarTE->Name), SymKind::Type);
auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
if (Ty == nullptr) {
if (!AutoVars || Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(VarTE->Name), VarTE->Name);
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
}
Ty = createRigidVar(getCanonicalText(VarTE->Name));
addBinding(getCanonicalText(VarTE->Name), new Forall(Ty), SymKind::Type);
Ty = createRigidVar(VarTE->Name->getCanonicalText());
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
}
ZEN_ASSERT(Ty->isVar());
N->setType(Ty);
@ -868,7 +891,7 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
auto RecTE = static_cast<RecordTypeExpression*>(N);
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil());
for (auto [Field, Comma]: RecTE->Fields) {
Ty = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty));
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty));
}
N->setType(Ty);
return Ty;
@ -980,7 +1003,7 @@ Type* Checker::inferExpression(Expression* X) {
Ty = new Type(TNil());
for (auto [Field, Comma]: Record->Fields) {
Ty = new Type(TField(
getCanonicalText(Field->Name),
Field->Name->getCanonicalText(),
new Type(TPresent(inferExpression(Field->getExpression()))),
Ty
));
@ -999,11 +1022,12 @@ Type* Checker::inferExpression(Expression* X) {
case NodeKind::ReferenceExpression:
{
auto Ref = static_cast<ReferenceExpression*>(X);
auto Name = Ref->Name.getCanonicalText();
ZEN_ASSERT(Ref->ModulePath.empty());
if (Ref->Name->is<IdentifierAlt>()) {
auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var);
if (Ref->Name.isIdentifierAlt()) {
auto Scm = lookup(Name, SymKind::Var);
if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(Ref->Name), Ref->Name);
DE.add<BindingNotFoundDiagnostic>(Name, Ref->Name);
Ty = createTypeVar();
break;
}
@ -1012,12 +1036,12 @@ Type* Checker::inferExpression(Expression* X) {
}
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
if (!Target) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(Ref->Name), Ref->Name);
DE.add<BindingNotFoundDiagnostic>(Name, Ref->Name);
Ty = createTypeVar();
break;
}
if (Target->getKind() == NodeKind::LetDeclaration) {
auto Let = static_cast<LetDeclaration*>(Target);
if (isa<FunctionDeclaration>(Target)) {
auto Let = static_cast<FunctionDeclaration*>(Target);
if (Let->IsCycleActive) {
Ty = Let->getType();
break;
@ -1026,7 +1050,7 @@ Type* Checker::inferExpression(Expression* X) {
infer(Let);
}
}
auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var);
auto Scm = lookup(Name, SymKind::Var);
ZEN_ASSERT(Scm);
Ty = instantiate(Scm, X);
break;
@ -1048,9 +1072,9 @@ Type* Checker::inferExpression(Expression* X) {
case NodeKind::InfixExpression:
{
auto Infix = static_cast<InfixExpression*>(X);
auto Scm = lookup(Infix->Operator->getText(), SymKind::Var);
auto Scm = lookup(Infix->Operator.getCanonicalText(), SymKind::Var);
if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(Infix->Operator->getText(), Infix->Operator);
DE.add<BindingNotFoundDiagnostic>(Infix->Operator.getCanonicalText(), Infix->Operator);
Ty = createTypeVar();
break;
}
@ -1091,7 +1115,7 @@ Type* Checker::inferExpression(Expression* X) {
auto K = static_cast<Identifier*>(Member->Name);
Ty = createTypeVar();
auto RestTy = createTypeVar();
makeEqual(new Type(TField(getCanonicalText(K), Ty, RestTy)), ExprTy, Member);
makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member);
break;
}
default:
@ -1138,20 +1162,20 @@ Type* Checker::inferPattern(
{
auto P = static_cast<BindPattern*>(Pattern);
auto Ty = createTypeVar();
addBinding(getCanonicalText(P->Name), new Forall(TVs, Constraints, Ty), SymKind::Var);
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var);
return Ty;
}
case NodeKind::NamedTuplePattern:
{
auto P = static_cast<NamedTuplePattern*>(Pattern);
auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var);
auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
std::vector<Type*> ElementTypes;
for (auto P2: P->Patterns) {
ElementTypes.push_back(inferPattern(P2, Constraints, TVs));
}
if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(P->Name), P->Name);
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar();
}
auto Ty = instantiate(Scm, P);
@ -1181,9 +1205,9 @@ Type* Checker::inferPattern(
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else {
FieldTy = createTypeVar();
addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
}
RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy));
RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
}
return RecordTy;
}
@ -1191,9 +1215,9 @@ Type* Checker::inferPattern(
case NodeKind::NamedRecordPattern:
{
auto P = static_cast<NamedRecordPattern*>(Pattern);
auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var);
auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(getCanonicalText(P->Name), P->Name);
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar();
}
auto RestField = getRestField(P->Fields);
@ -1214,9 +1238,9 @@ Type* Checker::inferPattern(
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else {
FieldTy = createTypeVar();
addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
}
RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy));
RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
@ -1287,7 +1311,14 @@ void Checker::populate(SourceFile* SF) {
std::stack<Node*> Stack;
void visitLetDeclaration(LetDeclaration* N) {
void visitFunctionDeclaration(FunctionDeclaration* N) {
RefGraph.addVertex(N);
Stack.push(N);
visitEachChild(N);
Stack.pop();
}
void visitVariableDeclaration(VariableDeclaration* N) {
RefGraph.addVertex(N);
Stack.push(N);
visitEachChild(N);
@ -1295,22 +1326,26 @@ void Checker::populate(SourceFile* SF) {
}
void visitReferenceExpression(ReferenceExpression* N) {
auto Y = static_cast<ReferenceExpression*>(N);
auto Def = Y->getScope()->lookup(Y->getSymbolPath());
// Name lookup failures will be reported directly in inferExpression().
if (Def == nullptr || Def->getKind() != NodeKind::LetDeclaration) {
auto Ref = static_cast<ReferenceExpression*>(N);
auto Def = Ref->getScope()->lookup(Ref->getSymbolPath());
if (Def == nullptr) {
// Name lookup failures will be reported directly in inferExpression().
return;
}
ZEN_ASSERT(isa<FunctionDeclaration>(Def) || isa<VariableDeclaration>(Def) || isa<Parameter>(Def));
// This case ensures that a deeply nested structure that references a
// parameter of a parent node but is not referenced itself is correctly handled.
// Note that the edge goes from the parent let to the parameter. This is normal.
if (Def->getKind() == NodeKind::Parameter) {
RefGraph.addEdge(Stack.top(), Def->Parent);
// if (Def->getKind() == NodeKind::Parameter) {
// RefGraph.addEdge(Stack.top(), Def->Parent);
// return;
// }
if (Stack.empty()) {
// An empty stack means we are traversing the toplevel of the source
// file, in which case we don't have anyting to connect with.
return;
}
if (!Stack.empty()) {
RefGraph.addEdge(Def, Stack.top());
}
RefGraph.addEdge(Def, Stack.top());
}
};
@ -1353,10 +1388,10 @@ void Checker::check(SourceFile *SF) {
auto TVs = new TVSet;
auto Constraints = new ConstraintSet;
for (auto N: Nodes) {
if (N->getKind() != NodeKind::LetDeclaration) {
if (!isa<FunctionDeclaration>(N)) {
continue;
}
auto Decl = static_cast<LetDeclaration*>(N);
auto Decl = static_cast<FunctionDeclaration*>(N);
forwardDeclareFunctionDeclaration(Decl, TVs, Constraints);
}
}

View file

@ -155,7 +155,12 @@ static std::string describe(NodeKind Type) {
return "'class'";
case NodeKind::InstanceKeyword:
return "'instance'";
case NodeKind::LetDeclaration:
case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
return "a let-declaration";
case NodeKind::VariableDeclaration:
return "a let-declaration";
case NodeKind::CallExpression:
return "a call-expression";

View file

@ -11,7 +11,7 @@ Value Evaluator::evaluateExpression(Expression* X, Env& Env) {
case NodeKind::ReferenceExpression:
{
auto RE = static_cast<ReferenceExpression*>(X);
return Env.lookup(getCanonicalText(RE->Name));
return Env.lookup(RE->Name.getCanonicalText());
// auto Decl = RE->getScope()->lookup(RE->getSymbolPath());
// ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration);
// return static_cast<FunctionDeclaration*>(Decl);
@ -48,7 +48,7 @@ void Evaluator::assignPattern(Pattern* P, Value& V, Env& E) {
case NodeKind::BindPattern:
{
auto BP = static_cast<BindPattern*>(P);
E.add(getCanonicalText(BP->Name), V);
E.add(BP->Name->getCanonicalText(), V);
break;
}
default:
@ -62,12 +62,12 @@ Value Evaluator::apply(Value Op, std::vector<Value> Args) {
{
auto Fn = Op.getDeclaration();
Env NewEnv;
for (auto [Param, Arg]: zen::zip(Fn->Params, Args)) {
for (auto [Param, Arg]: zen::zip(Fn->getParams(), Args)) {
assignPattern(Param->Pattern, Arg, NewEnv);
}
switch (Fn->Body->getKind()) {
switch (Fn->getBody()->getKind()) {
case NodeKind::LetExprBody:
return evaluateExpression(static_cast<LetExprBody*>(Fn->Body)->Expression, NewEnv);
return evaluateExpression(static_cast<LetExprBody*>(Fn->getBody())->Expression, NewEnv);
default:
ZEN_UNREACHABLE
}
@ -98,23 +98,28 @@ void Evaluator::evaluate(Node* N, Env& E) {
evaluateExpression(ES->Expression, E);
break;
}
case NodeKind::LetDeclaration:
case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration:
case NodeKind::NamedFunctionDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(N);
if (Decl->isFunction()) {
E.add(Decl->getNameAsString(), Decl);
} else {
Value V;
if (Decl->Body) {
switch (Decl->Body->getKind()) {
case NodeKind::LetExprBody:
{
auto Body = static_cast<LetExprBody*>(Decl->Body);
V = evaluateExpression(Body->Expression, E);
}
default:
ZEN_UNREACHABLE
auto Decl = static_cast<FunctionDeclaration*>(N);
E.add(Decl->getNameAsString(), Decl);
break;
}
case NodeKind::VariableDeclaration:
{
auto Decl = static_cast<VariableDeclaration*>(N);
Value V;
if (Decl->Body) {
switch (Decl->Body->getKind()) {
case NodeKind::LetExprBody:
{
auto Body = static_cast<LetExprBody*>(Decl->Body);
V = evaluateExpression(Body->Expression, E);
}
default:
ZEN_UNREACHABLE
}
}
break;

View file

@ -32,16 +32,6 @@
namespace bolt {
bool isOperator(Token* T) {
switch (T->getKind()) {
case NodeKind::VBar:
case NodeKind::CustomOperator:
return true;
default:
return false;
}
}
std::optional<OperatorInfo> OperatorTable::getInfix(Token* T) {
auto Match = Mapping.find(T->getText());
if (Match == Mapping.end() || !Match->second.isInfix()) {
@ -828,7 +818,7 @@ Expression* Parser::parsePrimitiveExpression() {
DE.add<UnexpectedTokenDiagnostic>(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt });
return nullptr;
}
return new ReferenceExpression(Annotations, ModulePath, static_cast<Symbol*>(T3));
return new ReferenceExpression(Annotations, ModulePath, Symbol::from_raw_node(T3));
}
case NodeKind::LParen:
{
@ -1025,7 +1015,7 @@ Expression* Parser::parseInfixOperatorAfterExpression(Expression* Left, int MinP
}
Right = NewRight;
}
Left = new InfixExpression(Left, T0, Right);
Left = new InfixExpression(Left, Operator::from_raw_node(T0), Right);
}
return Left;
}
@ -1141,17 +1131,31 @@ IfStatement* Parser::parseIfStatement() {
return new IfStatement(Parts);
}
LetDeclaration* Parser::parseLetDeclaration() {
enum class LetMode {
Prefix,
Infix,
Suffix,
Wrapped,
VarOrNamed,
};
Node* Parser::parseLetDeclaration() {
auto Annotations = parseAnnotations();
PubKeyword* Pub = nullptr;
ForeignKeyword* Foreign = nullptr;
LetKeyword* Let;
MutKeyword* Mut = nullptr;
Operator Op;
Symbol Sym;
Pattern* Name;
Parameter* Param;
Parameter* Left;
Parameter* Right;
std::vector<Parameter*> Params;
TypeAssert* TA = nullptr;
LetBody* Body = nullptr;
LetMode Mode;
auto T0 = Tokens.get();
if (T0->getKind() == NodeKind::PubKeyword) {
@ -1183,38 +1187,46 @@ LetDeclaration* Parser::parseLetDeclaration() {
auto T2 = Tokens.peek(0);
auto T3 = Tokens.peek(1);
auto T4 = Tokens.peek(2);
if (isOperator(T2)) {
if (isa<Operator>(T2)) {
// Prefix function declaration
Tokens.get();
auto P1 = parseNarrowPattern();
Params.push_back(new Parameter(P1, nullptr));
Name = new BindPattern(T2);
Param = new Parameter(P1, nullptr);
Op = Operator::from_raw_node(T2);
Mode = LetMode::Prefix;
goto after_params;
} else if (isOperator(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) {
} else if (isa<Operator>(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) {
// Sufffix function declaration
auto P1 = parseNarrowPattern();
Params.push_back(new Parameter(P1, nullptr));
Param = new Parameter(P1, nullptr);
Tokens.get();
Name = new BindPattern(T3);
Op = Operator::from_raw_node(T3);
Mode = LetMode::Suffix;
goto after_params;
} else if (T2->getKind() == NodeKind::LParen && isOperator(T3) && T4->getKind() == NodeKind::RParen) {
} else if (T2->getKind() == NodeKind::LParen && isa<Operator>(T3) && T4->getKind() == NodeKind::RParen) {
// Wrapped operator function declaration
Tokens.get();
Tokens.get();
Tokens.get();
Name = new BindPattern(
new WrappedOperator(
static_cast<class LParen*>(T2),
T3,
static_cast<class RParen*>(T3)
)
Sym = new WrappedOperator(
static_cast<class LParen*>(T2),
Operator::from_raw_node(T3),
static_cast<class RParen*>(T3)
);
} else if (isOperator(T3)) {
Mode = LetMode::Wrapped;
} else if (isa<Operator>(T3)) {
// Infix function declaration
auto P1 = parseNarrowPattern();
Params.push_back(new Parameter(P1, nullptr));
Left = new Parameter(P1, nullptr);
Tokens.get();
auto P2 = parseNarrowPattern();
Params.push_back(new Parameter(P2, nullptr));
Name = new BindPattern(T3);
Right = new Parameter(P2, nullptr);
Op = Operator::from_raw_node(T3);
Mode = LetMode::Infix;
goto after_params;
} else {
// Variable declaration or named function declaration
Mode = LetMode::VarOrNamed;
Name = parseNarrowPattern();
if (!Name) {
if (Pub) {
@ -1313,17 +1325,77 @@ after_params:
finish:
return new LetDeclaration(
Annotations,
Pub,
Foreign,
Let,
Mut,
Name,
Params,
TA,
Body
);
switch (Mode) {
case LetMode::Prefix:
return new PrefixFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Op,
Param,
TA,
Body
);
case LetMode::Suffix:
return new SuffixFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Param,
Op,
TA,
Body
);
case LetMode::Infix:
return new InfixFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Left,
Op,
Right,
TA,
Body
);
case LetMode::Wrapped:
return new NamedFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Sym,
Params,
TA,
Body
);
case LetMode::VarOrNamed:
if (Name->getKind() != NodeKind::BindPattern || Mut) {
// TODO assert Params is empty
return new VariableDeclaration(
Annotations,
Pub,
Foreign,
Let,
Mut,
Name,
TA,
Body
);
}
return new NamedFunctionDeclaration(
Annotations,
Pub,
Foreign,
Let,
Name->as<BindPattern>()->Name,
Params,
TA,
Body
);
}
}
Node* Parser::parseLetBodyElement() {