Refactor CST and enable typechecking of do-expressions

This commit is contained in:
Sam Vervaeck 2024-07-10 16:02:07 +02:00
parent 15dab8a7a8
commit 87bb0d0b10
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
12 changed files with 527 additions and 613 deletions

File diff suppressed because it is too large Load diff

View file

@ -99,10 +99,9 @@ public:
BOLT_GEN_CASE(PrefixExpression) BOLT_GEN_CASE(PrefixExpression)
BOLT_GEN_CASE(RecordExpressionField) BOLT_GEN_CASE(RecordExpressionField)
BOLT_GEN_CASE(RecordExpression) BOLT_GEN_CASE(RecordExpression)
BOLT_GEN_CASE(ExpressionStatement) BOLT_GEN_CASE(ReturnExpression)
BOLT_GEN_CASE(ReturnStatement) BOLT_GEN_CASE(IfExpression)
BOLT_GEN_CASE(IfStatement) BOLT_GEN_CASE(IfExpressionPart)
BOLT_GEN_CASE(IfStatementPart)
BOLT_GEN_CASE(TypeAssert) BOLT_GEN_CASE(TypeAssert)
BOLT_GEN_CASE(Parameter) BOLT_GEN_CASE(Parameter)
BOLT_GEN_CASE(LetBlockBody) BOLT_GEN_CASE(LetBlockBody)
@ -498,23 +497,15 @@ protected:
static_cast<D*>(this)->visitExpression(N); static_cast<D*>(this)->visitExpression(N);
} }
void visitStatement(Statement* N) { void visitReturnExpression(ReturnExpression* N) {
static_cast<D*>(this)->visitNode(N); static_cast<D*>(this)->visitExpression(N);
} }
void visitExpressionStatement(ExpressionStatement* N) { void visitIfExpression(IfExpression* N) {
static_cast<D*>(this)->visitStatement(N); static_cast<D*>(this)->visitExpression(N);
} }
void visitReturnStatement(ReturnStatement* N) { void visitIfExpressionPart(IfExpressionPart* N) {
static_cast<D*>(this)->visitStatement(N);
}
void visitIfStatement(IfStatement* N) {
static_cast<D*>(this)->visitStatement(N);
}
void visitIfStatementPart(IfStatementPart* N) {
static_cast<D*>(this)->visitNode(N); static_cast<D*>(this)->visitNode(N);
} }
@ -687,10 +678,9 @@ public:
BOLT_GEN_CHILD_CASE(PrefixExpression) BOLT_GEN_CHILD_CASE(PrefixExpression)
BOLT_GEN_CHILD_CASE(RecordExpressionField) BOLT_GEN_CHILD_CASE(RecordExpressionField)
BOLT_GEN_CHILD_CASE(RecordExpression) BOLT_GEN_CHILD_CASE(RecordExpression)
BOLT_GEN_CHILD_CASE(ExpressionStatement) BOLT_GEN_CHILD_CASE(ReturnExpression)
BOLT_GEN_CHILD_CASE(ReturnStatement) BOLT_GEN_CHILD_CASE(IfExpression)
BOLT_GEN_CHILD_CASE(IfStatement) BOLT_GEN_CHILD_CASE(IfExpressionPart)
BOLT_GEN_CHILD_CASE(IfStatementPart)
BOLT_GEN_CHILD_CASE(TypeAssert) BOLT_GEN_CHILD_CASE(TypeAssert)
BOLT_GEN_CHILD_CASE(Parameter) BOLT_GEN_CHILD_CASE(Parameter)
BOLT_GEN_CHILD_CASE(LetBlockBody) BOLT_GEN_CHILD_CASE(LetBlockBody)
@ -1163,14 +1153,7 @@ public:
BOLT_VISIT(N->RBrace); BOLT_VISIT(N->RBrace);
} }
void visitEachChild(ExpressionStatement* N) { void visitEachChild(ReturnExpression* N) {
for (auto A: N->Annotations) {
BOLT_VISIT(A);
}
BOLT_VISIT(N->Expression);
}
void visitEachChild(ReturnStatement* N) {
for (auto A: N->Annotations) { for (auto A: N->Annotations) {
BOLT_VISIT(A); BOLT_VISIT(A);
} }
@ -1178,13 +1161,13 @@ public:
BOLT_VISIT(N->E); BOLT_VISIT(N->E);
} }
void visitEachChild(IfStatement* N) { void visitEachChild(IfExpression* N) {
for (auto Part: N->Parts) { for (auto Part: N->Parts) {
BOLT_VISIT(Part); BOLT_VISIT(Part);
} }
} }
void visitEachChild(IfStatementPart* N) { void visitEachChild(IfExpressionPart* N) {
for (auto A: N->Annotations) { for (auto A: N->Annotations) {
BOLT_VISIT(A); BOLT_VISIT(A);
} }

View file

@ -89,15 +89,11 @@ class Checker {
Type* IntType; Type* IntType;
Type* BoolType; Type* BoolType;
Type* StringType; Type* StringType;
Type* UnitType;
public: public:
Checker(DiagnosticEngine& DE): Checker(DiagnosticEngine& DE);
DE(DE) {
IntType = new TCon("Int");
BoolType = new TCon("Bool");
StringType = new TCon("String");
}
Type* getIntType() const { Type* getIntType() const {
return IntType; return IntType;
@ -111,6 +107,10 @@ public:
return StringType; return StringType;
} }
Type* getUnitType() const {
return UnitType;
}
TVar* createTVar() { TVar* createTVar() {
return new TVar(); return new TVar();
} }

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <concepts>
#include <cstdlib> #include <cstdlib>
#include "zen/config.hpp" #include "zen/config.hpp"
@ -36,17 +37,10 @@ public:
}; };
template<typename D, typename B> template<typename T>
D* cast(B* base) { concept HoldsKind = requires (T a) {
ZEN_ASSERT(D::classof(base)); { a.getKind() } -> std::convertible_to<decltype(T::Kind)>;
return static_cast<D*>(base); };
}
template<typename D, typename B>
const D* cast(const B* base) {
ZEN_ASSERT(D::classof(base));
return static_cast<const D*>(base);
}
template<typename D, typename T> template<typename D, typename T>
bool isa(const T* value) { bool isa(const T* value) {
@ -54,4 +48,22 @@ bool isa(const T* value) {
return D::classof(value); return D::classof(value);
} }
template<HoldsKind D, typename T>
bool isa(const T* value) {
ZEN_ASSERT(value != nullptr);
return D::Kind == value->getKind();
}
template<typename D, typename B>
D* cast(B* base) {
ZEN_ASSERT(isa<D>(base));
return static_cast<D*>(base);
}
template<typename D, typename B>
const D* cast(const B* base) {
ZEN_ASSERT(isa<D>(base));
return static_cast<const D*>(base);
}
} }

View file

@ -75,9 +75,7 @@ class Parser {
std::optional<std::vector<std::tuple<RecordPatternField*, Comma*>>> parseRecordPatternFields(); std::optional<std::vector<std::tuple<RecordPatternField*, Comma*>>> parseRecordPatternFields();
template<typename T> template<typename T>
T* expectToken() { T* expectToken();
return static_cast<T*>(expectToken(getNodeType<T>()));
}
Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence); Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence);
@ -115,18 +113,15 @@ public:
Parameter* parseParam(); Parameter* parseParam();
ReferenceExpression* parseReferenceExpression(); ReferenceExpression* parseReferenceExpression();
Expression* parseUnaryExpression(); Expression* parseUnaryExpression();
Expression* parseExpression(); Expression* parseExpression();
BlockExpression* parseBlockExpression(std::vector<Annotation*> Annotations = {});
Expression* parseCallExpression(); Expression* parseCallExpression();
IfExpression* parseIfExpression();
IfStatement* parseIfStatement(); ReturnExpression* parseReturnExpression();
ReturnStatement* parseReturnStatement(); Expression* parseExpressionStatement();
ExpressionStatement* parseExpressionStatement();
Node* parseLetBodyElement(); Node* parseLetBodyElement();

View file

@ -7,7 +7,7 @@ namespace bolt {
TextFile::TextFile(ByteString Path, ByteString Text): TextFile::TextFile(ByteString Path, ByteString Text):
Path(Path), Text(Text) { Path(Path), Text(Text) {
LineOffsets.push_back(0); LineOffsets.push_back(0);
for (size_t I = 0; I < Text.size(); I++) { for (std::size_t I = 0; I < Text.size(); I++) {
auto Chr = Text[I]; auto Chr = Text[I];
if (Chr == '\n') { if (Chr == '\n') {
LineOffsets.push_back(I+1); LineOffsets.push_back(I+1);
@ -16,16 +16,16 @@ TextFile::TextFile(ByteString Path, ByteString Text):
LineOffsets.push_back(Text.size()); LineOffsets.push_back(Text.size());
} }
size_t TextFile::getLineCount() const { std::size_t TextFile::getLineCount() const {
return LineOffsets.size()-1; return LineOffsets.size()-1;
} }
size_t TextFile::getStartOffsetOfLine(size_t Line) const { std::size_t TextFile::getStartOffsetOfLine(std::size_t Line) const {
ZEN_ASSERT(Line-1 < LineOffsets.size()); ZEN_ASSERT(Line-1 < LineOffsets.size());
return LineOffsets[Line-1]; return LineOffsets[Line-1];
} }
size_t TextFile::getEndOffsetOfLine(size_t Line) const { std::size_t TextFile::getEndOffsetOfLine(std::size_t Line) const {
ZEN_ASSERT(Line <= LineOffsets.size()); ZEN_ASSERT(Line <= LineOffsets.size());
if (Line == LineOffsets.size()) { if (Line == LineOffsets.size()) {
return Text.size(); return Text.size();
@ -33,9 +33,9 @@ size_t TextFile::getEndOffsetOfLine(size_t Line) const {
return LineOffsets[Line]; return LineOffsets[Line];
} }
size_t TextFile::getLine(size_t Offset) const { std::size_t TextFile::getLine(std::size_t Offset) const {
ZEN_ASSERT(Offset < Text.size()); ZEN_ASSERT(Offset < Text.size());
for (size_t I = 0; I < LineOffsets.size(); ++I) { for (std::size_t I = 0; I < LineOffsets.size(); ++I) {
if (LineOffsets[I] > Offset) { if (LineOffsets[I] > Offset) {
return I; return I;
} }
@ -43,7 +43,7 @@ size_t TextFile::getLine(size_t Offset) const {
ZEN_UNREACHABLE ZEN_UNREACHABLE
} }
size_t TextFile::getColumn(size_t Offset) const { std::size_t TextFile::getColumn(std::size_t Offset) const {
auto Line = getLine(Offset); auto Line = getLine(Offset);
auto StartOffset = getStartOffsetOfLine(Line); auto StartOffset = getStartOffsetOfLine(Line);
return Offset - StartOffset + 1 ; return Offset - StartOffset + 1 ;
@ -60,7 +60,7 @@ ByteString TextFile::getText() const {
const SourceFile* Node::getSourceFile() const { const SourceFile* Node::getSourceFile() const {
const Node* CurrNode = this; const Node* CurrNode = this;
for (;;) { for (;;) {
if (CurrNode->Kind == NodeKind::SourceFile) { if (CurrNode->K == NodeKind::SourceFile) {
return static_cast<const SourceFile*>(CurrNode); return static_cast<const SourceFile*>(CurrNode);
} }
CurrNode = CurrNode->Parent; CurrNode = CurrNode->Parent;
@ -70,7 +70,7 @@ const SourceFile* Node::getSourceFile() const {
SourceFile* Node::getSourceFile() { SourceFile* Node::getSourceFile() {
Node* CurrNode = this; Node* CurrNode = this;
for (;;) { for (;;) {
if (CurrNode->Kind == NodeKind::SourceFile) { if (CurrNode->K == NodeKind::SourceFile) {
return static_cast<SourceFile*>(CurrNode); return static_cast<SourceFile*>(CurrNode);
} }
CurrNode = CurrNode->Parent; CurrNode = CurrNode->Parent;
@ -508,42 +508,34 @@ Token* PrefixExpression::getLastToken() const {
return Argument->getLastToken(); return Argument->getLastToken();
} }
Token* ExpressionStatement::getFirstToken() const { Token* ReturnExpression::getFirstToken() const {
return Expression->getFirstToken();
}
Token* ExpressionStatement::getLastToken() const {
return Expression->getLastToken();
}
Token* ReturnStatement::getFirstToken() const {
return ReturnKeyword; return ReturnKeyword;
} }
Token* ReturnStatement::getLastToken() const { Token* ReturnExpression::getLastToken() const {
if (E) { if (E) {
return E->getLastToken(); return E->getLastToken();
} }
return ReturnKeyword; return ReturnKeyword;
} }
Token* IfStatementPart::getFirstToken() const { Token* IfExpressionPart::getFirstToken() const {
return Keyword; return Keyword;
} }
Token* IfStatementPart::getLastToken() const { Token* IfExpressionPart::getLastToken() const {
if (Elements.size()) { if (Elements.size()) {
return Elements.back()->getLastToken(); return Elements.back()->getLastToken();
} }
return BlockStart; return BlockStart;
} }
Token* IfStatement::getFirstToken() const { Token* IfExpression::getFirstToken() const {
ZEN_ASSERT(Parts.size()); ZEN_ASSERT(Parts.size());
return Parts.front()->getFirstToken(); return Parts.front()->getFirstToken();
} }
Token* IfStatement::getLastToken() const { Token* IfExpression::getLastToken() const {
ZEN_ASSERT(Parts.size()); ZEN_ASSERT(Parts.size());
return Parts.back()->getLastToken(); return Parts.back()->getLastToken();
} }
@ -1049,7 +1041,7 @@ SymbolPath ReferenceExpression::getSymbolPath() const {
return SymbolPath { ModuleNames, Name.getCanonicalText() }; return SymbolPath { ModuleNames, Name.getCanonicalText() };
} }
bool TypedNode::classof(Node* N) { bool TypedNode::classof(const Node* N) {
return Expression::classof(N) return Expression::classof(N)
|| TypeExpression::classof(N) || TypeExpression::classof(N)
|| FunctionDeclaration::classof(N) || FunctionDeclaration::classof(N)

View file

@ -78,7 +78,13 @@ Type* substituteType(Type* Ty, const TVSub& Sub) {
} }
} }
Checker::Checker(DiagnosticEngine& DE):
DE(DE) {
IntType = new TCon("Int");
BoolType = new TCon("Bool");
StringType = new TCon("String");
UnitType = new TCon("()");
}
Type* Checker::instantiate(TypeScheme* Scm) { Type* Checker::instantiate(TypeScheme* Scm) {
TVSub Sub; TVSub Sub;
@ -103,6 +109,23 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
switch (Expr->getKind()) { switch (Expr->getKind()) {
case NodeKind::BlockExpression:
{
auto E = static_cast<BlockExpression*>(Expr);
auto N = E->Elements.size();
for (std::size_t I = 0; I+1 < N; ++I) {
auto Element = E->Elements[I];
auto CC = inferElement(Env, Element, RetTy);
mergeTo(Out, CC);
}
auto Last = E->Elements[N-1];
auto [CC, ResTy] = inferExpr(Env, cast<Expression>(Last), RetTy);
mergeTo(Out, CC);
Ty = ResTy;
break;
}
case NodeKind::ReferenceExpression: case NodeKind::ReferenceExpression:
{ {
auto E = static_cast<ReferenceExpression*>(Expr); auto E = static_cast<ReferenceExpression*>(Expr);
@ -169,6 +192,21 @@ std::tuple<ConstraintSet, Type*> Checker::inferExpr(TypeEnv& Env, Expression* Ex
break; break;
} }
case NodeKind::ReturnExpression:
{
auto E = static_cast<ReturnExpression*>(Expr);
if (E->hasExpression()) {
auto [ValOut, ValTy] = inferExpr(Env, E->getExpression(), RetTy);
mergeTo(Out, ValOut);
// Since evaluation stops at the return expression, it can be matched with any type.
Out.push_back(new CTypesEqual { ValTy, RetTy, E });
} else {
Out.push_back(new CTypesEqual { getUnitType(), RetTy, E });
}
Ty = createTVar();
break;
}
// TODO LambdaExpression // TODO LambdaExpression
default: default:
@ -255,7 +293,7 @@ ConstraintSet Checker::inferFunctionDeclaration(TypeEnv& Env, FunctionDeclaratio
if (Body != nullptr) { if (Body != nullptr) {
// TODO elminate BlockBody and replace with BlockExpr // TODO elminate BlockBody and replace with BlockExpr
ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody); ZEN_ASSERT(Body->getKind() == NodeKind::LetExprBody);
auto [BodyOut, BodyTy] = inferExpr(NewEnv, static_cast<LetExprBody*>(Body)->Expression, RetTy); auto [BodyOut, BodyTy] = inferExpr(NewEnv, cast<LetExprBody>(Body)->Expression, RetTy);
mergeTo(Out, BodyOut); mergeTo(Out, BodyOut);
Out.push_back(new CTypesEqual(RetTy, BodyTy, Body)); Out.push_back(new CTypesEqual(RetTy, BodyTy, Body));
} }
@ -351,7 +389,7 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
V.visit(N); V.visit(N);
}; };
std::vector<Statement*> Stmts; std::vector<Node*> Stmts;
for (auto Element: Elements) { for (auto Element: Elements) {
if (isa<FunctionDeclaration>(Element)) { if (isa<FunctionDeclaration>(Element)) {
@ -367,7 +405,7 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
populate(M, M->getExpression()); populate(M, M->getExpression());
} }
} else { } else {
Stmts.push_back(cast<Statement>(Element)); Stmts.push_back(Element);
} }
} }
@ -407,6 +445,11 @@ ConstraintSet Checker::inferMany(TypeEnv& Env, std::vector<Node*>& Elements, Typ
ConstraintSet Checker::inferElement(TypeEnv& Env, Node* N, Type* RetTy) { ConstraintSet Checker::inferElement(TypeEnv& Env, Node* N, Type* RetTy) {
if (isa<Expression>(N)) {
auto [Out, Ty] = inferExpr(Env, cast<Expression>(N), RetTy);
return Out;
}
switch (N->getKind()) { switch (N->getKind()) {
case NodeKind::PrefixFunctionDeclaration: case NodeKind::PrefixFunctionDeclaration:
@ -415,16 +458,9 @@ ConstraintSet Checker::inferElement(TypeEnv& Env, Node* N, Type* RetTy) {
case NodeKind::NamedFunctionDeclaration: case NodeKind::NamedFunctionDeclaration:
return inferFunctionDeclaration(Env, static_cast<FunctionDeclaration*>(N)); return inferFunctionDeclaration(Env, static_cast<FunctionDeclaration*>(N));
case NodeKind::ExpressionStatement: case NodeKind::ReturnExpression:
{ {
auto M = static_cast<ExpressionStatement*>(N); auto M = static_cast<ReturnExpression*>(N);
auto [Out, _] = inferExpr(Env, M->Expression, RetTy);
return Out;
}
case NodeKind::ReturnStatement:
{
auto M = static_cast<ReturnStatement*>(N);
if (!M->hasExpression()) { if (!M->hasExpression()) {
return {}; return {};
} }

View file

@ -172,9 +172,9 @@ static std::string describe(NodeKind Type) {
return "a literal expression"; return "a literal expression";
case NodeKind::MemberExpression: case NodeKind::MemberExpression:
return "an accessor of a member"; return "an accessor of a member";
case NodeKind::IfStatement: case NodeKind::IfExpression:
return "an if-statement"; return "an if-statement";
case NodeKind::IfStatementPart: case NodeKind::IfExpressionPart:
return "a branch of an if-statement"; return "a branch of an if-statement";
case NodeKind::VariantDeclaration: case NodeKind::VariantDeclaration:
return "a variant"; return "a variant";

View file

@ -97,6 +97,10 @@ Value Evaluator::apply(Value Op, std::vector<Value> Args) {
} }
void Evaluator::evaluate(Node* N, Env& E) { void Evaluator::evaluate(Node* N, Env& E) {
if (isa<Expression>(N)) {
evaluateExpression(cast<Expression>(N), E);
return;
}
switch (N->getKind()) { switch (N->getKind()) {
case NodeKind::SourceFile: case NodeKind::SourceFile:
{ {
@ -106,12 +110,6 @@ void Evaluator::evaluate(Node* N, Env& E) {
} }
break; break;
} }
case NodeKind::ExpressionStatement:
{
auto ES = static_cast<ExpressionStatement*>(N);
evaluateExpression(ES->Expression, E);
break;
}
case NodeKind::PrefixFunctionDeclaration: case NodeKind::PrefixFunctionDeclaration:
case NodeKind::InfixFunctionDeclaration: case NodeKind::InfixFunctionDeclaration:
case NodeKind::SuffixFunctionDeclaration: case NodeKind::SuffixFunctionDeclaration:

View file

@ -85,6 +85,18 @@ Parser::Parser(TextFile& File, Stream<Token*>& S, DiagnosticEngine& DE):
ExprOperators.add("$", OperatorFlags_InfixR, 0); ExprOperators.add("$", OperatorFlags_InfixR, 0);
} }
template<typename T>
T* Parser::expectToken() {
auto Tok = Tokens.peek();
if (Tok->getKind() != T::Kind) {
DE.add<UnexpectedTokenDiagnostic>(File, Tok, std::vector<NodeKind> { T::Kind });
return nullptr;
}
Tokens.get();
return static_cast<T*>(Tok);
}
Token* Parser::peekFirstTokenAfterAnnotationsAndModifiers() { Token* Parser::peekFirstTokenAfterAnnotationsAndModifiers() {
std::size_t I = 0; std::size_t I = 0;
for (;;) { for (;;) {
@ -107,16 +119,6 @@ Token* Parser::peekFirstTokenAfterAnnotationsAndModifiers() {
} }
} }
Token* Parser::expectToken(NodeKind Kind) {
auto T = Tokens.peek();
if (T->getKind() != Kind) {
DE.add<UnexpectedTokenDiagnostic>(File, T, std::vector<NodeKind> { Kind });
return nullptr;
}
Tokens.get();
return T;
}
ListPattern* Parser::parseListPattern() { ListPattern* Parser::parseListPattern() {
auto LBracket = expectToken<class LBracket>(); auto LBracket = expectToken<class LBracket>();
if (!LBracket) { if (!LBracket) {
@ -818,7 +820,7 @@ Expression* Parser::parsePrimitiveExpression() {
ModulePath.push_back(std::make_tuple(static_cast<IdentifierAlt*>(T1), static_cast<class Dot*>(T2))); ModulePath.push_back(std::make_tuple(static_cast<IdentifierAlt*>(T1), static_cast<class Dot*>(T2)));
} }
auto T3 = Tokens.get(); auto T3 = Tokens.get();
if (!T3->is<Identifier>() && !T3->is<IdentifierAlt>()) { if (!isa<Identifier>(T3) && !isa<IdentifierAlt>(T3)) {
for (auto [Name, Dot]: ModulePath) { for (auto [Name, Dot]: ModulePath) {
Name->unref(); Name->unref();
Dot->unref(); Dot->unref();
@ -886,37 +888,11 @@ after_tuple_elements:
case NodeKind::MatchKeyword: case NodeKind::MatchKeyword:
return parseMatchExpression(); return parseMatchExpression();
case NodeKind::DoKeyword: case NodeKind::DoKeyword:
{ return parseBlockExpression();
Tokens.get(); case NodeKind::IfKeyword:
auto T1 = expectToken(NodeKind::BlockStart); return parseIfExpression();
if (!T1) { case NodeKind::ReturnKeyword:
BOLT_EACH_UNREF(Annotations); return parseReturnExpression();
T0->unref();
return nullptr;
}
std::vector<Node*> Elements;
for (;;) {
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::BlockEnd) {
Tokens.get()->unref();
break;
}
auto Element = parseLetBodyElement();
if (Element == nullptr) {
BOLT_EACH_UNREF(Annotations);
T0->unref();
T1->unref();
BOLT_EACH_UNREF(Elements);
return nullptr;
}
Elements.push_back(Element);
}
return new BlockExpression {
static_cast<class DoKeyword*>(T0),
static_cast<BlockStart*>(T1),
Elements
};
}
case NodeKind::IntegerLiteral: case NodeKind::IntegerLiteral:
case NodeKind::StringLiteral: case NodeKind::StringLiteral:
Tokens.get(); Tokens.get();
@ -938,6 +914,42 @@ after_tuple_elements:
} }
} }
BlockExpression* Parser::parseBlockExpression(std::vector<Annotation*> Annotations) {
auto DoKeyword = expectToken<class DoKeyword>();
if (!DoKeyword) {
BOLT_EACH_UNREF(Annotations);
return nullptr;
}
auto BlockStart = expectToken<class BlockStart>();
if (!BlockStart) {
BOLT_EACH_UNREF(Annotations);
DoKeyword->unref();
return nullptr;
}
std::vector<Node*> Elements;
for (;;) {
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::BlockEnd) {
Tokens.get()->unref();
break;
}
auto Element = parseLetBodyElement();
if (Element == nullptr) {
BOLT_EACH_UNREF(Annotations);
DoKeyword->unref();
BlockStart->unref();
BOLT_EACH_UNREF(Elements);
return nullptr;
}
Elements.push_back(Element);
}
return new BlockExpression {
DoKeyword,
BlockStart,
Elements
};
}
Expression* Parser::parseMemberExpression() { Expression* Parser::parseMemberExpression() {
auto E = parsePrimitiveExpression(); auto E = parsePrimitiveExpression();
if (!E) { if (!E) {
@ -1068,17 +1080,17 @@ Expression* Parser::parseExpression() {
return parseInfixOperatorAfterExpression(Left, 0); return parseInfixOperatorAfterExpression(Left, 0);
} }
ExpressionStatement* Parser::parseExpressionStatement() { Expression* Parser::parseExpressionStatement() {
auto E = parseExpression(); auto E = parseExpression();
if (!E) { if (!E) {
skipPastLineFoldEnd(); skipPastLineFoldEnd();
return nullptr; return nullptr;
} }
checkLineFoldEnd(); checkLineFoldEnd();
return new ExpressionStatement(E); return E;
} }
ReturnStatement* Parser::parseReturnStatement() { ReturnExpression* Parser::parseReturnExpression() {
auto Annotations = parseAnnotations(); auto Annotations = parseAnnotations();
auto ReturnKeyword = expectToken<class ReturnKeyword>(); auto ReturnKeyword = expectToken<class ReturnKeyword>();
if (!ReturnKeyword) { if (!ReturnKeyword) {
@ -1094,16 +1106,14 @@ ReturnStatement* Parser::parseReturnStatement() {
Expression = parseExpression(); Expression = parseExpression();
if (!Expression) { if (!Expression) {
ReturnKeyword->unref(); ReturnKeyword->unref();
skipPastLineFoldEnd();
return nullptr; return nullptr;
} }
checkLineFoldEnd();
} }
return new ReturnStatement(Annotations, ReturnKeyword, Expression); return new ReturnExpression(Annotations, ReturnKeyword, Expression);
} }
IfStatement* Parser::parseIfStatement() { IfExpression* Parser::parseIfExpression() {
std::vector<IfStatementPart*> Parts; std::vector<IfExpressionPart*> Parts;
auto Annotations = parseAnnotations(); auto Annotations = parseAnnotations();
auto IfKeyword = expectToken<class IfKeyword>(); auto IfKeyword = expectToken<class IfKeyword>();
if (!IfKeyword) { if (!IfKeyword) {
@ -1136,7 +1146,7 @@ IfStatement* Parser::parseIfStatement() {
} }
} }
Tokens.get()->unref(); // Always a LineFoldEnd Tokens.get()->unref(); // Always a LineFoldEnd
Parts.push_back(new IfStatementPart(Annotations, IfKeyword, Test, T1, Then)); Parts.push_back(new IfExpressionPart(Annotations, IfKeyword, Test, T1, Then));
for (;;) { for (;;) {
auto T3 = peekFirstTokenAfterAnnotationsAndModifiers(); auto T3 = peekFirstTokenAfterAnnotationsAndModifiers();
if (T3->getKind() != NodeKind::ElseKeyword && T3->getKind() != NodeKind::ElifKeyword) { if (T3->getKind() != NodeKind::ElseKeyword && T3->getKind() != NodeKind::ElifKeyword) {
@ -1168,12 +1178,12 @@ IfStatement* Parser::parseIfStatement() {
} }
} }
Tokens.get()->unref(); // Always a LineFoldEnd Tokens.get()->unref(); // Always a LineFoldEnd
Parts.push_back(new IfStatementPart(Annotations, T3, Test, T4, Alt)); Parts.push_back(new IfExpressionPart(Annotations, T3, Test, T4, Alt));
if (T3->getKind() == NodeKind::ElseKeyword) { if (T3->getKind() == NodeKind::ElseKeyword) {
break; break;
} }
} }
return new IfStatement(Parts); return new IfExpression(Parts);
} }
enum class LetMode { enum class LetMode {
@ -1435,7 +1445,7 @@ finish:
Pub, Pub,
Foreign, Foreign,
Let, Let,
Name->as<BindPattern>()->Name, cast<BindPattern>(Name)->Name,
Params, Params,
TA, TA,
Body Body
@ -1448,10 +1458,6 @@ Node* Parser::parseLetBodyElement() {
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::LetKeyword: case NodeKind::LetKeyword:
return parseLetDeclaration(); return parseLetDeclaration();
case NodeKind::ReturnKeyword:
return parseReturnStatement();
case NodeKind::IfKeyword:
return parseIfStatement();
default: default:
return parseExpressionStatement(); return parseExpressionStatement();
} }
@ -1550,7 +1556,7 @@ InstanceDeclaration* Parser::parseInstanceDeclaration() {
std::vector<TypeExpression*> TypeExps; std::vector<TypeExpression*> TypeExps;
for (;;) { for (;;) {
auto T1 = Tokens.peek(); auto T1 = Tokens.peek();
if (T1->is<BlockStart>()) { if (isa<BlockStart>(T1)) {
break; break;
} }
auto TE = parseTypeExpression(); auto TE = parseTypeExpression();
@ -1578,7 +1584,7 @@ InstanceDeclaration* Parser::parseInstanceDeclaration() {
std::vector<Node*> Elements; std::vector<Node*> Elements;
for (;;) { for (;;) {
auto T2 = Tokens.peek(); auto T2 = Tokens.peek();
if (T2->is<BlockEnd>()) { if (isa<BlockEnd>(T2)) {
Tokens.get()->unref(); Tokens.get()->unref();
break; break;
} }
@ -1656,7 +1662,7 @@ ClassDeclaration* Parser::parseClassDeclaration() {
std::vector<Node*> Elements; std::vector<Node*> Elements;
for (;;) { for (;;) {
auto T2 = Tokens.peek(); auto T2 = Tokens.peek();
if (T2->is<BlockEnd>()) { if (isa<BlockEnd>(T2)) {
Tokens.get()->unref(); Tokens.get()->unref();
break; break;
} }
@ -1867,8 +1873,6 @@ Node* Parser::parseSourceElement() {
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::LetKeyword: case NodeKind::LetKeyword:
return parseLetDeclaration(); return parseLetDeclaration();
case NodeKind::IfKeyword:
return parseIfStatement();
case NodeKind::ClassKeyword: case NodeKind::ClassKeyword:
return parseClassDeclaration(); return parseClassDeclaration();
case NodeKind::InstanceKeyword: case NodeKind::InstanceKeyword:
@ -1886,7 +1890,7 @@ SourceFile* Parser::parseSourceFile() {
std::vector<Node*> Elements; std::vector<Node*> Elements;
for (;;) { for (;;) {
auto T0 = Tokens.peek(); auto T0 = Tokens.peek();
if (T0->is<EndOfFile>()) { if (isa<EndOfFile>(T0)) {
break; break;
} }
auto Element = parseSourceElement(); auto Element = parseSourceElement();

View file

@ -49,11 +49,11 @@ void Scope::scan(Node* X) {
} }
void Scope::scanChild(Node* X) { void Scope::scanChild(Node* X) {
if (isa<Expression>(X)) {
return;
}
switch (X->getKind()) { switch (X->getKind()) {
case NodeKind::LetExprBody: case NodeKind::LetExprBody:
case NodeKind::ExpressionStatement:
case NodeKind::IfStatement:
case NodeKind::ReturnStatement:
break; break;
case NodeKind::LetBlockBody: case NodeKind::LetBlockBody:
{ {

View file

@ -135,12 +135,12 @@ int main(int Argc, const char* Argv[]) {
std::multimap<std::size_t, unsigned> Expected; std::multimap<std::size_t, unsigned> Expected;
void visitExpressionAnnotation(ExpressionAnnotation* N) { void visitExpressionAnnotation(ExpressionAnnotation* N) {
if (N->getExpression()->is<CallExpression>()) { if (isa<CallExpression>(N->getExpression())) {
auto CE = static_cast<CallExpression*>(N->getExpression()); auto CE = static_cast<CallExpression*>(N->getExpression());
if (CE->Function->is<ReferenceExpression>()) { if (isa<ReferenceExpression>(CE->Function)) {
auto RE = static_cast<ReferenceExpression*>(CE->Function); auto RE = static_cast<ReferenceExpression*>(CE->Function);
if (RE->getNameAsString() == "expect_diagnostic") { if (RE->getNameAsString() == "expect_diagnostic") {
ZEN_ASSERT(CE->Args.size() == 1 && CE->Args[0]->is<LiteralExpression>()); ZEN_ASSERT(CE->Args.size() == 1 && isa<LiteralExpression>(CE->Args[0]));
Expected.emplace(N->Parent->getStartLine(), static_cast<LiteralExpression*>(CE->Args[0])->getAsInt()); Expected.emplace(N->Parent->getStartLine(), static_cast<LiteralExpression*>(CE->Args[0])->getAsInt());
} }
} }